upload files
This view is limited to 50 files because it contains too many changes.
- RepCodec/.gitignore +160 -0
- RepCodec/LICENSE +428 -0
- RepCodec/README.md +273 -0
- RepCodec/examples/data2vec_audio.py +541 -0
- RepCodec/examples/data2vec_feature_reader.py +87 -0
- RepCodec/examples/dump_feature.py +142 -0
- RepCodec/examples/feature_utils.py +70 -0
- RepCodec/examples/hubert_feature_reader.py +64 -0
- RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens +0 -0
- RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens +0 -0
- RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens +0 -0
- RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens +0 -0
- RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens +0 -0
- RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens +0 -0
- RepCodec/examples/whisper_feature_reader.py +110 -0
- RepCodec/examples/whisper_model.py +58 -0
- RepCodec/repcodec/RepCodec.py +84 -0
- RepCodec/repcodec/configs/repcodec_dim1024.yaml +18 -0
- RepCodec/repcodec/configs/repcodec_dim1280.yaml +18 -0
- RepCodec/repcodec/configs/repcodec_dim768.yaml +18 -0
- RepCodec/repcodec/layers/conv_layer.py +95 -0
- RepCodec/repcodec/layers/vq_module.py +155 -0
- RepCodec/repcodec/modules/decoder.py +109 -0
- RepCodec/repcodec/modules/encoder.py +89 -0
- RepCodec/repcodec/modules/projector.py +32 -0
- RepCodec/repcodec/modules/quantizer.py +46 -0
- RepCodec/repcodec/modules/residual_unit.py +39 -0
- RepCodec/repcodec/tokenize.py +212 -0
- RepCodec/setup.py +31 -0
- RepCodec/train.py +228 -0
- RepCodec/train_configs/ex_dim768_mse.yaml +74 -0
- RepCodec/trainer/autoencoder.py +287 -0
- __pycache__/post_process_audio.cpython-310.pyc +0 -0
- __pycache__/vocoder.cpython-310.pyc +0 -0
- decoders/config.yaml +15 -0
- decoders/decoder_131000.pth +3 -0
- decoders/decoder_151000.pth +3 -0
- descriptaudiocodec/dac/__init__.py +16 -0
- descriptaudiocodec/dac/__main__.py +36 -0
- descriptaudiocodec/dac/__pycache__/__init__.cpython-310.pyc +0 -0
- descriptaudiocodec/dac/__pycache__/__init__.cpython-38.pyc +0 -0
- descriptaudiocodec/dac/__pycache__/__init__.cpython-39.pyc +0 -0
- descriptaudiocodec/dac/compare/__init__.py +0 -0
- descriptaudiocodec/dac/compare/encodec.py +54 -0
- descriptaudiocodec/dac/model/__init__.py +4 -0
- descriptaudiocodec/dac/model/__pycache__/__init__.cpython-310.pyc +0 -0
- descriptaudiocodec/dac/model/__pycache__/__init__.cpython-39.pyc +0 -0
- descriptaudiocodec/dac/model/__pycache__/base.cpython-310.pyc +0 -0
- descriptaudiocodec/dac/model/__pycache__/base.cpython-39.pyc +0 -0
- descriptaudiocodec/dac/model/__pycache__/dac.cpython-310.pyc +0 -0
RepCodec/.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
RepCodec/LICENSE
ADDED
@@ -0,0 +1,428 @@
MIT License

Copyright (c) ByteDance, Inc. and its affiliates.
Copyright (c) Chutong Meng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors

Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.

Section 1 -- Definitions.

a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.

b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.

c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.

d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.

e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.

f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.

g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.

h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.

i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.

j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.

k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.

l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

a. License grant.

1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:

a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and

b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.

2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.

3. Term. The term of this Public License is specified in Section
6(a).

4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.

5. Downstream recipients.

a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.

b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.

6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).

b. Other rights.

1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.

2. Patent and trademark rights are not licensed under this
Public License.

3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

a. Attribution.

1. If You Share the Licensed Material (including in modified
form), You must:

a. retain the following if it is supplied by the Licensor
with the Licensed Material:

i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);

ii. a copyright notice;

iii. a notice that refers to this Public License;

iv. a notice that refers to the disclaimer of
warranties;

v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;

b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and

c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.

2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.

3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.

4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;

b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and

c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.

Section 6 -- Term and Termination.

a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.

b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:

1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or

2. upon express reinstatement by the Licensor.

For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.

Section 7 -- Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.

c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.

d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
RepCodec/README.md
ADDED
@@ -0,0 +1,273 @@
# RepCodec: A Speech Representation Codec for Speech Tokenization

> [**RepCodec: A Speech Representation Codec for Speech Tokenization**](https://arxiv.org/abs/2309.00169)

## Introduction

**RepCodec** is a speech tokenization method for converting a speech waveform into a sequence of discrete semantic
tokens.
The main idea is to train a representation codec which learns a vector quantization codebook by reconstructing the
input speech representations from speech encoders like HuBERT or data2vec.
Extensive experiments show that RepCodec significantly outperforms the widely used k-means clustering approach in both
speech understanding and generation.
Also, RepCodec generalizes well across various speech encoders and languages.

<img src="images/RepCodec.png" alt="RepCodec" width="1000" />

## RepCodec Models

| Feature Type | Speech Data | RepCodec Model |
|--------------|-------------|----------------|
| [HuBERT base](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 9 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_base_l9](https://drive.google.com/file/d/1XD0HKl607FFjri2-VJT7lHQeSpxsCCFO/view?usp=sharing) |
| [HuBERT large](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_large_l18](https://drive.google.com/file/d/1mTbm5GeJ7gp_5L3QLP-JGXdf8RnRw5n6/view?usp=sharing) |
| [data2vec base](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 6 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_base_l6](https://drive.google.com/file/d/1d8sf3Ko_fYM9zlaiwxK_4xusLRKV5EMd/view?usp=sharing) |
| [data2vec large](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_large_l18](https://drive.google.com/file/d/1nuRIHaejT-uVi4cluftbT8o_JZqar5SU/view?usp=sharing) |
| [Whisper medium](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 24 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_medium_l24](https://drive.google.com/file/d/1V6YJSA2V4iywXrecJAN0oqsa3aHowexZ/view?usp=sharing) |
| [Whisper large-v2](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 32 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_large_l32](https://drive.google.com/file/d/1k_X7ZMPg8iOeDrIJe70v6CHfFygzufXC/view?usp=sharing) |

## Speech Tokenization Using Pre-Trained Models

### Installation

Please first install RepCodec by

```
git clone https://github.com/mct10/RepCodec.git
cd RepCodec
pip install .
```

We used Python 3.9.18 and PyTorch 1.12.1 to test the usage, but the code should be compatible with other recent Python
and PyTorch versions.

### Representation Preparation

We adapt the `dump_hubert_feature.py` script
from [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert/simple_kmeans#hubert-feature)
to support dumping representations from **data2vec**, **HuBERT**, or **Whisper** encoders.

If you use our script (`examples/dump_feature.py`), please also install the following packages:

```
pip install npy_append_array soundfile
```

Additionally, if you want to dump representations from

- **data2vec** or **HuBERT**: please
  follow [fairseq's instructions](https://github.com/facebookresearch/fairseq#requirements-and-installation) to install
  the latest fairseq.

- **Whisper**: please follow [Whisper's instructions](https://github.com/openai/whisper/tree/main#setup) to install the
  latest Whisper.

Then, you can follow the given examples to dump representations:

```
# Example 1: dump from HuBERT base layer 9
# (for data2vec, simply change "model_type" to data2vec and "ckpt_path" to the path of the data2vec model)

layer=9

python3 examples/dump_feature.py \
    --model_type hubert \
    --tsv_path /path/to/tsv/file \
    --ckpt_path /path/to/HuBERT/model \
    --layer ${layer} \
    --feat_dir /dir/to/save/representations


# Example 2: dump from Whisper medium layer 24

layer=24

python3 examples/dump_feature.py \
    --model_type whisper \
    --tsv_path /path/to/tsv/file \
    --whisper_root /directory/to/save/whisper/model \
    --whisper_name medium \
    --layer ${layer} \
    --feat_dir /dir/to/save/representations
```

Explanations of the arguments:

- **model_type:** choose from `data2vec`, `hubert`, and `whisper`.

- **tsv_path:** path of the tsv file.
  It should have the following format:

  ```
  /dir/to/dataset
  path_of_utterance_1 number_of_frames
  path_of_utterance_2 number_of_frames
  ```

  You can follow [this script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py)
  to generate the tsv file.

  For example, by running

  ```
  python wav2vec_manifest.py \
      /dir/to/LibriSpeech/dev-clean \
      --dest /dir/to/manifest \
      --ext flac \
      --valid-percent 0
  ```

  you can obtain the `dev-clean.tsv` in `/dir/to/manifest` for LibriSpeech. (By default, the output file name
  is `train.tsv`. Remember to rename the file.)

  It should be similar to:

  ```
  /dir/to/LibriSpeech/dev-clean
  2277/149896/2277-149896-0026.flac 78720
  2277/149896/2277-149896-0005.flac 89600
  2277/149896/2277-149896-0033.flac 45520
  ```

- **ckpt_path**:
  must be provided for data2vec and HuBERT.
  You need to download the model
  from the [data2vec website](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2)
  or the [HuBERT website](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models)
  yourself.
  `--ckpt_path` is the path of the data2vec/HuBERT model.
- **whisper_root** and **whisper_name**:
  **BOTH** `--whisper_root` and `--whisper_name` must be provided for Whisper.
  If there is no corresponding model in `--whisper_root`, the script will download it for you.

- **layer**:
  the Transformer encoder layer from which the representations are extracted.
  It is **1-based**.
  For example, if layer=9, then the outputs from the 9<sup>th</sup> Transformer encoder layer are dumped.
  Range: [1, number of Transformer encoder layers].

- **feat_dir**: The output representations will be saved to `${feat_dir}/0_1.npy`
  and `${feat_dir}/0_1.len` (see the loading sketch after this list).

For other useful functionalities (e.g., sharding), please check the argument list in `examples/dump_feature.py`.

+
### Command Line Usage
|
155 |
+
|
156 |
+
We expect to have `${feat_dir}/0_1.npy` and `${feat_dir}/0_1.len` in the provided
|
157 |
+
directory `/dir/to/representaitons`.
|
158 |
+
|
159 |
+
Also, the tsv file should be the **same** as the one used in [Representation Preparation](#representation-preparation).
|
160 |
+
|
161 |
+
```
|
162 |
+
repcodec /dir/to/representaitons \
|
163 |
+
--model /path/to/repcodec/model \
|
164 |
+
--tsv_path /path/to/tsv/file \
|
165 |
+
[--model_config_path /path/to/train/config] \
|
166 |
+
[--use_gpu] \
|
167 |
+
[--out_dir /path/to/output]
|
168 |
+
```
|
169 |
+
|
170 |
+
If you trained the model yourself following [Training New RepCodec Models](#training-new-repcodec-models),
|
171 |
+
please provide the training config file using `--model_config_path`.
|
172 |
+
If you use the model we provide [here](#repcodec-models), then you do not have to provide that.
|
173 |
+
|
174 |
+
This command will tokenize the representations and the output discrete tokens will be saved to `${out_dir}/tokens`.
|
175 |
+
The tokens are in the same order as the provided tsv file.
|
176 |
+
|
177 |
+
An example of the output file:
|
178 |
+
|
179 |
+
```
|
180 |
+
/dir/to/LibriSpeech/dev-clean
|
181 |
+
2277/149896/2277-149896-0026.flac 696 696 198 198 198 498 ...
|
182 |
+
2277/149896/2277-149896-0005.flac 696 696 198 198 198 907 ...
|
183 |
+
2277/149896/2277-149896-0033.flac 696 696 198 198 198 696 ...
|
184 |
+
```
|
185 |
+
|
186 |
+
Under `examples/tokens`, we provide some token files as references. They are obtained from LibriSpeech dev-clean subset
|
187 |
+
using the 6 types of representations and corresponding [RepCodec Models](#repcodec-models).
|
188 |
+
Your results should be very similar to ours.
|
189 |
+
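If you need the tokens programmatically rather than as text, the file format shown above is easy to parse. A minimal sketch (not part of the package; the path is a placeholder) that reads it into a dict mapping each utterance path to its list of integer token IDs:

```python
token_file = "/path/to/output/tokens"  # placeholder path

utt2tokens = {}
with open(token_file) as f:
    root = f.readline().strip()  # first line is the dataset root directory
    for line in f:
        utt_path, *ids = line.split()
        utt2tokens[utt_path] = [int(i) for i in ids]

print(f"{len(utt2tokens)} utterances under {root}")
```
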
### Python Usage

```python
import torch
import yaml

from repcodec.RepCodec import RepCodec

# for feature types of HuBERT base & data2vec base, please use repcodec_dim768.yaml;
# for feature types of HuBERT large & data2vec large & Whisper medium, please use repcodec_dim1024.yaml;
# for feature types of Whisper large-v2, please use repcodec_dim1280.yaml
config = "repcodec/configs/repcodec_dim768.yaml"
with open(config) as fp:
    conf = yaml.load(fp, Loader=yaml.FullLoader)

model = RepCodec(**conf)
model.load_state_dict(torch.load("./hubert_base_l9.pkl", map_location="cpu")["model"]["repcodec"])
model.quantizer.initial()
model.eval()

# input shape: (batch size, hidden dim, sequence length)
random_features = torch.randn(size=(1, 768, 100))
with torch.no_grad():
    x = model.encoder(random_features)
    z = model.projector(x)
    _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
    tokens = idx.cpu().data.numpy().tolist()[0]
```

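To run the same snippet on real dumped representations instead of `random_features`, the features of one utterance (a `(num_frames, hidden_dim)` array obtained in [Representation Preparation](#representation-preparation)) only need to be moved to the `(batch, hidden_dim, num_frames)` layout the model expects. A minimal sketch, reusing `model` from the block above and assuming `utt_feat` is such a NumPy array (the file name is a placeholder):

```python
import numpy as np
import torch

utt_feat = np.load("one_utterance.npy")  # placeholder: (num_frames, hidden_dim) features of one utterance
x = torch.from_numpy(utt_feat).float().T.unsqueeze(0)  # -> (1, hidden_dim, num_frames)

with torch.no_grad():
    z = model.projector(model.encoder(x))
    _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
tokens = idx[0].cpu().tolist()  # token IDs for this utterance
```
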
## Training New RepCodec Models

We use a config file to set up all the training configurations, e.g., data, model architecture,
optimizer, and scheduler.
We provide an example [here](./train_configs/ex_dim768_mse.yaml).

Please first install the required packages following [Installation](#installation)
and prepare the representations following [Representation Preparation](#representation-preparation).

The input data directory is expected to have the following structure:
```
/dir/to/representations/
    train_set_name/
        0_1.npy
        0_1.len
    valid_set_name/
        0_1.npy
        0_1.len
    test_set_name/
        0_1.npy
        0_1.len
```

The names of the subsets should be the same as the fields in the config file.

Then, you can run training by
```
python train.py \
    -c /path/to/config/file \
    --tag $tag \
    --exp_root exp
```

`tag` is the name of the output folder.
All outputs will be saved to `exp_root/tag/`.

## Acknowledgements

Our implementation is based on [facebookresearch/AudioDec](https://github.com/facebookresearch/AudioDec).
We thank them for open-sourcing their code!

## Citation

If you find our work useful, please cite the following article.

```
@misc{huang2023repcodec,
    title={RepCodec: A Speech Representation Codec for Speech Tokenization},
    author={Zhichao Huang and Chutong Meng and Tom Ko},
    year={2023},
    eprint={2309.00169},
    archivePrefix={arXiv},
    primaryClass={eess.AS}
}
```
RepCodec/examples/data2vec_audio.py
ADDED
@@ -0,0 +1,541 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

# ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py

import logging
import math
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import (
    GradMultiply,
    LayerNorm,
)
from fairseq.utils import index_put


logger = logging.getLogger(__name__)


@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecAudioConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
|
332 |
+
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
|
333 |
+
|
334 |
+
padding_mask = torch.zeros(
|
335 |
+
features.shape[:2], dtype=features.dtype, device=features.device
|
336 |
+
)
|
337 |
+
|
338 |
+
# these two operations make sure that all values
|
339 |
+
# before the output lengths indices are attended to
|
340 |
+
padding_mask[
|
341 |
+
(
|
342 |
+
torch.arange(padding_mask.shape[0], device=padding_mask.device),
|
343 |
+
output_lengths - 1,
|
344 |
+
)
|
345 |
+
] = 1
|
346 |
+
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
|
347 |
+
else:
|
348 |
+
padding_mask = None
|
349 |
+
|
350 |
+
if self.post_extract_proj is not None:
|
351 |
+
features = self.post_extract_proj(features)
|
352 |
+
|
353 |
+
pre_encoder_features = None
|
354 |
+
if self.cfg.ema_transformer_only:
|
355 |
+
pre_encoder_features = features.clone()
|
356 |
+
|
357 |
+
features = self.dropout_input(features)
|
358 |
+
|
359 |
+
if mask:
|
360 |
+
x, mask_indices = self.apply_mask(
|
361 |
+
features,
|
362 |
+
padding_mask,
|
363 |
+
mask_indices=mask_indices,
|
364 |
+
mask_channel_indices=mask_channel_indices,
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
x = features
|
368 |
+
mask_indices = None
|
369 |
+
|
370 |
+
x, layer_results = self.encoder(
|
371 |
+
x,
|
372 |
+
padding_mask=padding_mask,
|
373 |
+
layer=layer,
|
374 |
+
)
|
375 |
+
|
376 |
+
if features_only:
|
377 |
+
return {
|
378 |
+
"x": x,
|
379 |
+
"padding_mask": padding_mask,
|
380 |
+
"layer_results": layer_results,
|
381 |
+
}
|
382 |
+
|
383 |
+
result = {
|
384 |
+
"losses": {},
|
385 |
+
}
|
386 |
+
|
387 |
+
with torch.no_grad():
|
388 |
+
self.ema.model.eval()
|
389 |
+
|
390 |
+
if self.cfg.ema_transformer_only:
|
391 |
+
y, layer_results = self.ema.model.extract_features(
|
392 |
+
pre_encoder_features,
|
393 |
+
padding_mask=padding_mask,
|
394 |
+
min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
|
395 |
+
)
|
396 |
+
y = {
|
397 |
+
"x": y,
|
398 |
+
"padding_mask": padding_mask,
|
399 |
+
"layer_results": layer_results,
|
400 |
+
}
|
401 |
+
else:
|
402 |
+
y = self.ema.model.extract_features(
|
403 |
+
source=source,
|
404 |
+
padding_mask=orig_padding_mask,
|
405 |
+
mask=False,
|
406 |
+
)
|
407 |
+
|
408 |
+
target_layer_results = [l[2] for l in y["layer_results"]]
|
409 |
+
|
410 |
+
permuted = False
|
411 |
+
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
|
412 |
+
target_layer_results = [
|
413 |
+
tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
|
414 |
+
]
|
415 |
+
permuted = True
|
416 |
+
|
417 |
+
if self.cfg.batch_norm_target_layer:
|
418 |
+
target_layer_results = [
|
419 |
+
F.batch_norm(
|
420 |
+
tl.float(), running_mean=None, running_var=None, training=True
|
421 |
+
)
|
422 |
+
for tl in target_layer_results
|
423 |
+
]
|
424 |
+
|
425 |
+
if self.cfg.instance_norm_target_layer:
|
426 |
+
target_layer_results = [
|
427 |
+
F.instance_norm(tl.float()) for tl in target_layer_results
|
428 |
+
]
|
429 |
+
|
430 |
+
if permuted:
|
431 |
+
target_layer_results = [
|
432 |
+
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
|
433 |
+
]
|
434 |
+
|
435 |
+
if self.cfg.group_norm_target_layer:
|
436 |
+
target_layer_results = [
|
437 |
+
F.layer_norm(tl.float(), tl.shape[-2:])
|
438 |
+
for tl in target_layer_results
|
439 |
+
]
|
440 |
+
|
441 |
+
if self.cfg.layer_norm_target_layer:
|
442 |
+
target_layer_results = [
|
443 |
+
F.layer_norm(tl.float(), tl.shape[-1:])
|
444 |
+
for tl in target_layer_results
|
445 |
+
]
|
446 |
+
|
447 |
+
y = sum(target_layer_results) / len(target_layer_results)
|
448 |
+
|
449 |
+
if self.cfg.layer_norm_targets:
|
450 |
+
y = F.layer_norm(y.float(), y.shape[-1:])
|
451 |
+
|
452 |
+
if self.cfg.instance_norm_targets:
|
453 |
+
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
|
454 |
+
|
455 |
+
if not permuted:
|
456 |
+
y = y.transpose(0, 1)
|
457 |
+
|
458 |
+
y = y[mask_indices]
|
459 |
+
|
460 |
+
x = x[mask_indices]
|
461 |
+
x = self.final_proj(x)
|
462 |
+
|
463 |
+
sz = x.size(-1)
|
464 |
+
|
465 |
+
if self.loss_beta == 0:
|
466 |
+
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
|
467 |
+
else:
|
468 |
+
loss = F.smooth_l1_loss(
|
469 |
+
x.float(), y.float(), reduction="none", beta=self.loss_beta
|
470 |
+
).sum(dim=-1)
|
471 |
+
|
472 |
+
if self.loss_scale is not None:
|
473 |
+
scale = self.loss_scale
|
474 |
+
else:
|
475 |
+
scale = 1 / math.sqrt(sz)
|
476 |
+
|
477 |
+
result["losses"]["regression"] = loss.sum() * scale
|
478 |
+
|
479 |
+
if "sample_size" not in result:
|
480 |
+
result["sample_size"] = loss.numel()
|
481 |
+
|
482 |
+
with torch.no_grad():
|
483 |
+
result["target_var"] = self.compute_var(y)
|
484 |
+
result["pred_var"] = self.compute_var(x.float())
|
485 |
+
|
486 |
+
if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
|
487 |
+
logger.error(
|
488 |
+
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
|
489 |
+
)
|
490 |
+
raise Exception(
|
491 |
+
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
|
492 |
+
)
|
493 |
+
if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
|
494 |
+
logger.error(
|
495 |
+
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
|
496 |
+
)
|
497 |
+
raise Exception(
|
498 |
+
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
|
499 |
+
)
|
500 |
+
|
501 |
+
if self.ema is not None:
|
502 |
+
result["ema_decay"] = self.ema.get_decay() * 1000
|
503 |
+
|
504 |
+
return result
|
505 |
+
|
506 |
+
@staticmethod
|
507 |
+
def compute_var(y):
|
508 |
+
y = y.view(-1, y.size(-1))
|
509 |
+
if dist.is_initialized():
|
510 |
+
zc = torch.tensor(y.size(0)).cuda()
|
511 |
+
zs = y.sum(dim=0)
|
512 |
+
zss = (y ** 2).sum(dim=0)
|
513 |
+
|
514 |
+
dist.all_reduce(zc)
|
515 |
+
dist.all_reduce(zs)
|
516 |
+
dist.all_reduce(zss)
|
517 |
+
|
518 |
+
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
|
519 |
+
return torch.sqrt(var + 1e-6).mean()
|
520 |
+
else:
|
521 |
+
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
522 |
+
|
523 |
+
def extract_features(
|
524 |
+
self, source, padding_mask, mask=False, layer=None
|
525 |
+
):
|
526 |
+
res = self.forward(
|
527 |
+
source,
|
528 |
+
padding_mask,
|
529 |
+
mask=mask,
|
530 |
+
features_only=True,
|
531 |
+
layer=layer,
|
532 |
+
)
|
533 |
+
return res
|
534 |
+
|
535 |
+
def remove_pretraining_modules(self, last_layer=None):
|
536 |
+
self.final_proj = None
|
537 |
+
self.ema = None
|
538 |
+
if last_layer is not None:
|
539 |
+
self.encoder.layers = nn.ModuleList(
|
540 |
+
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
|
541 |
+
)
|
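The teacher-target construction above is spread across several config branches; the following condensed, self-contained sketch shows the same idea (average the top-k layer outputs of the EMA teacher, normalize each layer, then optionally layer-normalize the averaged target). The helper name and the choice of instance normalization are illustrative, not part of this file.

import torch.nn.functional as F

def build_regression_target(layer_outputs, k, layer_norm_targets=True):
    # layer_outputs: list of (T, B, C) tensors from the EMA teacher's transformer layers.
    # Instance-normalize each of the last k layers (cf. instance_norm_target_layer), then average.
    normed = [
        F.instance_norm(t.permute(1, 2, 0).float()).transpose(1, 2)  # (T, B, C) -> (B, T, C)
        for t in layer_outputs[-k:]
    ]
    y = sum(normed) / len(normed)
    if layer_norm_targets:
        y = F.layer_norm(y, y.shape[-1:])
    return y  # (B, T, C) regression target for the masked positions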
RepCodec/examples/data2vec_feature_reader.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging

import torch
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.audio.audio_utils import get_features_or_waveform
from omegaconf import OmegaConf

from data2vec_audio import Data2VecAudioModel

logger = logging.getLogger("dump_feature")


class Data2vecFeatureReader(object):
    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        state = load_checkpoint_to_cpu(ckpt_path)
        cfg = state["cfg"]
        # load task
        task = tasks.setup_task(cfg.task, from_checkpoint=True)
        task.load_state_dict(state["task_state"])
        # load model config
        if "layer_type" not in cfg.model:
            # fix a missing key
            model_config = {k: v for k, v in cfg.model.items()}
            model_config["layer_type"] = "transformer"
            model_config = OmegaConf.create(model_config)
        else:
            model_config = cfg.model

        # fix param name in the state
        state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
        state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
        del state["model"]["_ema"]

        # load model
        model = Data2VecAudioModel.build_model(model_config)
        model.load_state_dict(
            state["model"], strict=True, model_cfg=model_config
        )

        self.device = device
        logger.info(f"device = {self.device}")

        self.model = model.eval().to(self.device)
        self.task = task
        self.layer = layer - 1  # the layer argument is 1-based; convert to a 0-based index
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len=ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float().to(self.device)
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                res = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    layer=self.layer,
                )
                feat_chunk = res["x"]
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
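A rough usage sketch for the reader above (not part of the upload; the checkpoint and audio paths are placeholders):

# Placeholders: point these at a real data2vec checkpoint and a 16 kHz mono wav/flac file.
reader = Data2vecFeatureReader("path/to/data2vec_base.pt", layer=6, device="cuda", max_chunk=1600000)
feats = reader.get_feats("path/to/utterance.flac")  # (num_frames, feature_dim) torch.Tensor
print(feats.shape)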
RepCodec/examples/dump_feature.py
ADDED
@@ -0,0 +1,142 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging
import os
import sys

from feature_utils import get_path_iterator, dump_feature

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_feature")


def main(
    model_type: str,
    tsv_path: str,
    ckpt_path: str,
    whisper_root: str,
    whisper_name: str,
    layer: int,
    nshard: int,
    rank: int,
    feat_dir: str,
    max_chunk: int,
    use_cpu: bool = False
):
    device = "cpu" if use_cpu else "cuda"

    # some checks
    if model_type in ["hubert", "data2vec"]:
        assert ckpt_path and os.path.exists(ckpt_path)
    elif model_type in ["whisper"]:
        assert whisper_name and whisper_root
    else:
        raise ValueError(f"Unsupported model type {model_type}")

    reader = None
    if model_type == "hubert":
        from hubert_feature_reader import HubertFeatureReader
        reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "data2vec":
        from data2vec_feature_reader import Data2vecFeatureReader
        reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "whisper":
        from whisper_feature_reader import WhisperFeatureReader
        reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)

    assert reader is not None

    generator, num = get_path_iterator(tsv_path, nshard, rank)
    dump_feature(reader, generator, num, nshard, rank, feat_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        required=True,
        type=str,
        choices=["data2vec", "hubert", "whisper"],
        help="the type of the speech encoder."
    )
    parser.add_argument(
        "--tsv_path",
        required=True,
        type=str,
        help="the path to the tsv file."
    )
    parser.add_argument(
        "--ckpt_path",
        required=False,
        type=str,
        default=None,
        help="path to the speech model. must provide for HuBERT and data2vec"
    )
    parser.add_argument(
        "--whisper_root",
        required=False,
        type=str,
        default=None,
        help="root dir to download/store whisper model. must provide for whisper model."
    )
    parser.add_argument(
        "--whisper_name",
        required=False,
        type=str,
        default=None,
        help="name of whisper model. e.g., large-v2. must provide for whisper model."
    )
    parser.add_argument(
        "--layer",
        required=True,
        type=int,
        help="which layer of the model. this is 1-based."
    )
    parser.add_argument(
        "--feat_dir",
        required=True,
        type=str,
        help="the output dir to save the representations."
    )
    parser.add_argument(
        "--nshard",
        required=False,
        type=int,
        default=1,
        help="total number of shards."
    )
    parser.add_argument(
        "--rank",
        required=False,
        type=int,
        default=0,
        help="shard id of this process."
    )
    parser.add_argument(
        "--max_chunk",
        type=int,
        default=1600000,
        help="max number of audio samples of each chunk."
    )
    parser.add_argument(
        "--use_cpu",
        default=False,
        action="store_true",
        help="whether to use cpu instead of gpu."
    )
    args = parser.parse_args()
    logger.info(args)

    main(**vars(args))
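The same entry point can also be driven from Python instead of the command line; a sketch with placeholder paths (layer 9 mirrors the hubert_base_l9 tokens shipped under examples/tokens):

# All paths are placeholders; the whisper_* arguments stay None because model_type is not "whisper".
main(
    model_type="hubert",
    tsv_path="path/to/dev-clean.tsv",
    ckpt_path="path/to/hubert_base_ls960.pt",
    whisper_root=None,
    whisper_name=None,
    layer=9,
    nshard=1,
    rank=0,
    feat_dir="path/to/features",
    max_chunk=1600000,
    use_cpu=False,
)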
RepCodec/examples/feature_utils.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

# ref: https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py

import logging
import os
import sys

import tqdm
from npy_append_array import NpyAppendArray


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("feature_utils")


def get_shard_range(tot, nshard, rank):
    assert rank < nshard and rank >= 0, f"invalid rank/nshard {rank}/{nshard}"
    start = round(tot / nshard * rank)
    end = round(tot / nshard * (rank + 1))
    assert start < end, f"start={start}, end={end}"
    logger.info(
        f"rank {rank} of {nshard}, process {end-start} "
        f"({start}-{end}) out of {tot}"
    )
    return start, end


def get_path_iterator(tsv, nshard, rank):
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
        start, end = get_shard_range(len(lines), nshard, rank)
        lines = lines[start:end]

        def iterate():
            for line in lines:
                subpath, nsample = line.split("\t")
                yield f"{root}/{subpath}", int(nsample)

    return iterate, len(lines)


def dump_feature(reader, generator, num, nshard, rank, feat_dir):
    iterator = generator()

    feat_path = f"{feat_dir}/{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")
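For reference, get_path_iterator() expects the fairseq-style manifest: the first line is the audio root and every following line is a tab-separated relative path and sample count. A tiny sketch that writes such a file (the file names and sample counts are made up):

# Hypothetical two-utterance manifest in the format parsed by get_path_iterator().
with open("dev-clean.tsv", "w") as f:
    f.write("/data/LibriSpeech/dev-clean\n")
    f.write("84/121123/84-121123-0000.flac\t154800\n")
    f.write("84/121123/84-121123-0001.flac\t208560\n")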
RepCodec/examples/hubert_feature_reader.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging

import fairseq
import torch
import torch.nn.functional as F

from fairseq.data.audio.audio_utils import get_features_or_waveform

logger = logging.getLogger("dump_feature")


class HubertFeatureReader(object):
    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])

        self.device = device
        logger.info(f"device = {self.device}")

        self.model = model[0].eval().to(self.device)
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len=ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float().to(self.device)
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                feat_chunk, _ = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    output_layer=self.layer,
                )
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens
ADDED
The diff for this file is too large to render.
RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens
ADDED
The diff for this file is too large to render.
RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens
ADDED
The diff for this file is too large to render.
RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens
ADDED
The diff for this file is too large to render.
RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens
ADDED
The diff for this file is too large to render.
RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens
ADDED
The diff for this file is too large to render.
RepCodec/examples/whisper_feature_reader.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and
# Whisper (https://github.com/openai/whisper/)

import io
import logging
import os
from typing import Optional, Union

import soundfile as sf
import torch
from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
from whisper.audio import log_mel_spectrogram
from whisper.model import ModelDimensions

from whisper_model import Whisper_

logger = logging.getLogger("dump_feature")


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    But we will load a `Whisper_` model for feature extraction.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


class WhisperFeatureReader(object):
    def __init__(self, root, ckpt, layer, device):
        self.device = device
        logger.info(f"device = {self.device}")

        self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
        self.model.decoder = None  # to save some memory by deleting the decoder
        self.layer = layer  # one-based

    def read_audio(self, path, ref_len=None):
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
            hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]
        return hidden.contiguous()
RepCodec/examples/whisper_model.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and
# Whisper (https://github.com/openai/whisper/)

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions


class AudioEncoder_(AudioEncoder):
    def __init__(self, *args, **kwargs):
        super(AudioEncoder_, self).__init__(*args, **kwargs)

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        length_x = x.shape[1]
        if length_x > self.positional_embedding.shape[0]:
            self.register_buffer("positional_embedding", sinusoids(length_x, self.positional_embedding.shape[1]))
            self.positional_embedding = self.positional_embedding.to(x.device)
        x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)

        if target_layer is None:
            target_layer = len(self.blocks)

        for block in self.blocks[:target_layer]:
            x = block(x)

        return x


class Whisper_(Whisper):
    def __init__(self, dims: ModelDimensions):
        super(Whisper_, self).__init__(dims)
        # replace audio encoder with our audio encoder
        self.encoder = AudioEncoder_(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        return self.encoder.extract_feature(mel, target_layer)
RepCodec/repcodec/RepCodec.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn

from repcodec.modules.decoder import Decoder
from repcodec.modules.encoder import Encoder
from repcodec.modules.projector import Projector
from repcodec.modules.quantizer import Quantizer


class RepCodec(nn.Module):
    def __init__(
        self,
        input_channels=768,
        output_channels=768,
        encode_channels=768,
        decode_channels=768,
        code_dim=768,
        codebook_num=1,
        codebook_size=1024,
        bias=True,
        enc_ratios=(1, 1),
        dec_ratios=(1, 1),
        enc_strides=(1, 1),
        dec_strides=(1, 1),
        enc_kernel_size=3,
        dec_kernel_size=3,
        enc_block_dilations=(1, 1),
        enc_block_kernel_size=3,
        dec_block_dilations=(1, 1),
        dec_block_kernel_size=3
    ):
        super().__init__()

        self.input_channels = input_channels

        self.encoder = Encoder(
            input_channels=input_channels,
            encode_channels=encode_channels,
            channel_ratios=enc_ratios,
            strides=enc_strides,
            kernel_size=enc_kernel_size,
            bias=bias,
            block_dilations=enc_block_dilations,
            unit_kernel_size=enc_block_kernel_size
        )

        self.decoder = Decoder(
            code_dim=code_dim,
            output_channels=output_channels,
            decode_channels=decode_channels,
            channel_ratios=dec_ratios,
            strides=dec_strides,
            kernel_size=dec_kernel_size,
            bias=bias,
            block_dilations=dec_block_dilations,
            unit_kernel_size=dec_block_kernel_size
        )

        self.projector = Projector(
            input_channels=self.encoder.out_channels,
            code_dim=code_dim,
            kernel_size=3,
            stride=1,
            bias=False
        )

        self.quantizer = Quantizer(
            code_dim=code_dim,
            codebook_num=codebook_num,
            codebook_size=codebook_size
        )

    def forward(self, x):
        x = self.encoder(x)
        z = self.projector(x)
        zq, vqloss, perplexity = self.quantizer(z)
        y = self.decoder(zq)
        return y, zq, z, vqloss, perplexity
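A quick shape check of the model above, using random features in place of real speech representations (the defaults match the 768-dimensional configuration below):

import torch

model = RepCodec()                      # 768-dim defaults, single codebook of 1024 codes
x = torch.randn(2, 768, 100)            # (batch, input_channels, frames)
y, zq, z, vqloss, perplexity = model(x)
print(y.shape, zq.shape)                # both (2, 768, 100): unit strides mean no downsampling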
RepCodec/repcodec/configs/repcodec_dim1024.yaml
ADDED
@@ -0,0 +1,18 @@
input_channels: 1024
output_channels: 1024
encode_channels: 1024
decode_channels: 1024
code_dim: 1024
codebook_num: 1
codebook_size: 1024
bias: true
enc_ratios: [ 1, 1 ]
dec_ratios: [ 1, 1 ]
enc_strides: [ 1, 1 ] # no downsampling
dec_strides: [ 1, 1 ]
enc_kernel_size: 3
dec_kernel_size: 3
enc_block_dilations: [ 1, 1 ]
enc_block_kernel_size: 3
dec_block_dilations: [ 1, 1 ]
dec_block_kernel_size: 3
RepCodec/repcodec/configs/repcodec_dim1280.yaml
ADDED
@@ -0,0 +1,18 @@
input_channels: 1280
output_channels: 1280
encode_channels: 1280
decode_channels: 1280
code_dim: 1280
codebook_num: 1
codebook_size: 1024
bias: true
enc_ratios: [ 1, 1 ]
dec_ratios: [ 1, 1 ]
enc_strides: [ 1, 1 ] # no downsampling
dec_strides: [ 1, 1 ]
enc_kernel_size: 3
dec_kernel_size: 3
enc_block_dilations: [ 1, 1 ]
enc_block_kernel_size: 3
dec_block_dilations: [ 1, 1 ]
dec_block_kernel_size: 3
RepCodec/repcodec/configs/repcodec_dim768.yaml
ADDED
@@ -0,0 +1,18 @@
input_channels: 768
output_channels: 768
encode_channels: 768
decode_channels: 768
code_dim: 768
codebook_num: 1
codebook_size: 1024
bias: true
enc_ratios: [ 1, 1 ]
dec_ratios: [ 1, 1 ]
enc_strides: [ 1, 1 ] # no downsampling
dec_strides: [ 1, 1 ]
enc_kernel_size: 3
dec_kernel_size: 3
enc_block_dilations: [ 1, 1 ]
enc_block_kernel_size: 3
dec_block_dilations: [ 1, 1 ]
dec_block_kernel_size: 3
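These configs map one-to-one onto the RepCodec constructor arguments, so a model can be built straight from a file. A sketch, assuming the repcodec package from this upload is importable and the path below is adjusted to your checkout:

import yaml
from repcodec.RepCodec import RepCodec

with open("RepCodec/repcodec/configs/repcodec_dim768.yaml") as f:
    conf = yaml.safe_load(f)
model = RepCodec(**conf).eval()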
RepCodec/repcodec/layers/conv_layer.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn


class Conv1d1x1(nn.Conv1d):
    """1x1 Conv1d."""

    def __init__(self, in_channels, out_channels, bias=True):
        super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)


class Conv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        x = self.conv(x)
        return x


class ConvTranspose1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.deconv(x)
        return x
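One property worth noting: with the default padding of (kernel_size - 1) // 2 * dilation, a stride-1 Conv1d leaves the time dimension unchanged. A small check:

import torch

conv = Conv1d(in_channels=4, out_channels=4, kernel_size=3)   # padding defaults to 1
out = conv(torch.randn(1, 4, 100))
print(out.shape)                                              # torch.Size([1, 4, 100])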
RepCodec/repcodec/layers/vq_module.py
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantize(nn.Module):
    """Vector quantization w/ exponential moving averages (EMA)"""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        decay=0.8,
        commitment=1.,
        eps=1e-5,
        n_embed=None,
    ):
        super().__init__()
        n_embed = self.default(n_embed, codebook_size)

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.commitment = commitment

        embed = torch.randn(dim, n_embed)
        self.register_buffer('embed', embed)
        self.register_buffer('cluster_size', torch.zeros(n_embed))
        self.register_buffer('embed_avg', embed.clone())

    @property
    def codebook(self):
        return self.embed.transpose(0, 1)

    def exists(self, val):
        return val is not None

    def default(self, val, d):
        return val if self.exists(val) else d

    def ema_inplace(self, moving_avg, new, decay):
        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))

    def laplace_smoothing(self, x, n_categories, eps=1e-5):
        return (x + eps) / (x.sum() + n_categories * eps)

    def forward(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))

        if self.training:
            self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.ema_inplace(self.embed_avg, embed_sum, self.decay)
            cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        loss = F.mse_loss(quantize.detach(), input) * self.commitment
        quantize = input + (quantize - input).detach()

        avg_probs = torch.mean(embed_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantize, loss, perplexity

    def forward_index(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
        quantize = input + (quantize - input).detach()

        return quantize, embed_ind


class ResidualVQ(nn.Module):
    """ Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        num_quantizers,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x):
        quantized_out = 0.
        residual = x
        all_losses = []
        all_perplexities = []
        for layer in self.layers:
            quantized, loss, perplexity = layer(residual)
            # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
            # We found that considering only the 1st layer VQ's gradient results in better performance
            # residual = residual - quantized.detach()  # considering all layers' gradients
            residual = residual - quantized  # considering only the first layer's gradient
            quantized_out = quantized_out + quantized
            all_losses.append(loss)
            all_perplexities.append(perplexity)
        all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
        return quantized_out, all_losses, all_perplexities

    def forward_index(self, x, flatten_idx=False):
        quantized_out = 0.
        residual = x
        all_indices = []
        for i, layer in enumerate(self.layers):
            quantized, indices = layer.forward_index(residual)
            # residual = residual - quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            if flatten_idx:
                indices += (self.codebook_size * i)
            all_indices.append(indices)
        all_indices = torch.stack(all_indices)
        return quantized_out, all_indices.squeeze(1)

    def initial(self):
        self.codebook = []
        for layer in self.layers:
            self.codebook.append(layer.codebook)
        self.codebook_size = self.codebook[0].size(0)
        self.codebook = torch.stack(self.codebook)
        self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))

    def lookup(self, indices):
        quantized_out = F.embedding(indices, self.codebook)  # Num x T x C
        return torch.sum(quantized_out, dim=0, keepdim=True)
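A toy check of the EMA quantizer above, showing the expected shapes and the straight-through output (the sizes are arbitrary):

import torch

vq = VectorQuantize(dim=8, codebook_size=16)
x = torch.randn(4, 10, 8)                                 # (..., dim); here (batch, time, dim)
quantized, loss, perplexity = vq(x)
print(quantized.shape, loss.item(), perplexity.item())    # (4, 10, 8), commitment loss, codebook usage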
RepCodec/repcodec/modules/decoder.py
ADDED
@@ -0,0 +1,109 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch
import torch.nn as nn

from RepCodec.repcodec.layers.conv_layer import Conv1d, ConvTranspose1d
from RepCodec.repcodec.modules.residual_unit import ResidualUnit


class DecoderBlock(nn.Module):
    """ Decoder block (no up-sampling) """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True
    ):
        super().__init__()

        if stride == 1:
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
                stride=stride,
                bias=bias,
            )
        else:
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        self.res_units = torch.nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            self.res_units += [
                ResidualUnit(out_channels, out_channels,
                             kernel_size=unit_kernel_size,
                             dilation=dilation)
            ]
        self.num_res = len(self.res_units)

    def forward(self, x):
        x = self.conv(x)
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        code_dim: int,
        output_channels: int,
        decode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False
        )

        self.conv_blocks = torch.nn.ModuleList()
        for idx, stride in enumerate(strides):
            in_channels = int(decode_channels * channel_ratios[idx])
            if idx < (len(channel_ratios) - 1):
                out_channels = int(decode_channels * channel_ratios[idx + 1])
            else:
                out_channels = decode_channels
            self.conv_blocks += [
                DecoderBlock(
                    in_channels, out_channels, stride,
                    dilations=block_dilations, unit_kernel_size=unit_kernel_size,
                    bias=bias
                )
            ]
        self.num_blocks = len(self.conv_blocks)

        self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)

    def forward(self, z):
        x = self.conv1(z)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        x = self.conv2(x)
        return x
RepCodec/repcodec/modules/encoder.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch
import torch.nn as nn

from RepCodec.repcodec.layers.conv_layer import Conv1d
from RepCodec.repcodec.modules.residual_unit import ResidualUnit


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True
    ):
        super().__init__()
        self.res_units = torch.nn.ModuleList()
        for dilation in dilations:
            self.res_units += [
                ResidualUnit(in_channels, in_channels,
                             kernel_size=unit_kernel_size,
                             dilation=dilation)
            ]
        self.num_res = len(self.res_units)

        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3 if stride == 1 else (2 * stride),  # special case: stride=1, do not use kernel=2
            stride=stride,
            bias=bias,
        )

    def forward(self, x):
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels,
            out_channels=encode_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False
        )
        self.conv_blocks = torch.nn.ModuleList()
        in_channels = encode_channels
        for idx, stride in enumerate(strides):
            out_channels = int(encode_channels * channel_ratios[idx])  # could be float
            self.conv_blocks += [
                EncoderBlock(in_channels, out_channels, stride,
                             dilations=block_dilations, unit_kernel_size=unit_kernel_size,
                             bias=bias)
            ]
            in_channels = out_channels
        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        return x
RepCodec/repcodec/modules/projector.py
ADDED
@@ -0,0 +1,32 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn

from repcodec.layers.conv_layer import Conv1d


class Projector(nn.Module):
    def __init__(
        self,
        input_channels: int,
        code_dim: int,
        kernel_size=3,
        stride=1,
        bias=False
    ):
        super().__init__()
        self.project = Conv1d(
            input_channels,
            code_dim,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias
        )

    def forward(self, x):
        return self.project(x)
RepCodec/repcodec/modules/quantizer.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn

from repcodec.layers.vq_module import ResidualVQ


class Quantizer(nn.Module):
    def __init__(
            self,
            code_dim: int,
            codebook_num: int,
            codebook_size: int,
    ):
        super().__init__()
        self.codebook = ResidualVQ(
            dim=code_dim,
            num_quantizers=codebook_num,
            codebook_size=codebook_size
        )

    def initial(self):
        self.codebook.initial()

    def forward(self, z):
        zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
        zq = zq.transpose(2, 1)
        return zq, vqloss, perplexity

    def inference(self, z):
        zq, indices = self.codebook.forward_index(z.transpose(2, 1))
        zq = zq.transpose(2, 1)
        return zq, indices

    def encode(self, z):
        zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
        return zq, indices

    def decode(self, indices):
        z = self.codebook.lookup(indices)
        return z
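A minimal usage sketch for the Quantizer wrapper (illustrative only; the shapes follow the transposes above and the index handling in repcodec/tokenize.py further down):

import torch
from repcodec.modules.quantizer import Quantizer

quantizer = Quantizer(code_dim=768, codebook_num=1, codebook_size=1024)
quantizer.initial()                    # tokenize.load_model calls this once after loading weights

z = torch.randn(2, 768, 50)            # (batch, code dim, frames)
zq, indices = quantizer.inference(z)   # zq: (batch, code dim, frames); indices: one row per codebook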
RepCodec/repcodec/modules/residual_unit.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn

from RepCodec.repcodec.layers.conv_layer import Conv1d, Conv1d1x1


class ResidualUnit(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            dilation=1,
            bias=False,
            nonlinear_activation="ELU",
            nonlinear_activation_params={},
    ):
        super().__init__()
        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y
RepCodec/repcodec/tokenize.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from pathlib import Path
from typing import Tuple, List, Optional

import numpy as np
import torch
import yaml
from tqdm import tqdm

from repcodec.RepCodec import RepCodec

ALL_MODELS = {
    "data2vec_base_l6": 768,
    "data2vec_large_l18": 1024,
    "hubert_base_l9": 768,
    "hubert_large_l18": 1024,
    "whisper_medium_l24": 1024,
    "whisper_large_l32": 1280
}


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "in_dir",
        type=str,
        help="directory of representations to be tokenized."
    )
    parser.add_argument(
        "--model",
        required=True,
        type=str,
        help="path of the RepCodec model."
    )
    parser.add_argument(
        "--tsv_path",
        required=True,
        type=str,
        help="path of the tsv file."
    )
    parser.add_argument(
        "--model_config_path",
        default=None,
        type=str,
        help="please provide this training config if you are using a model you trained yourself."
    )
    parser.add_argument(
        "--n_shard",
        required=False,
        type=int,
        default=1,
        help="number of shards of representations."
    )
    parser.add_argument(
        "--use_gpu",
        default=False,
        action="store_true",
        help="whether to use gpu for inference."
    )
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="number of utterances in each mini batch."
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default=".",
        help="the directory to save the output."
    )
    return parser.parse_args()


def load_model(model_path: str, config_path: Optional[str] = None):
    if config_path is None:
        name = os.path.basename(model_path).strip(".pkl")
        assert name in ALL_MODELS.keys(), f"Cannot find configs for {model_path}. " \
                                          f"Please provide the config file you used for training."
        config = os.path.join(os.path.dirname(__file__), "configs", f"repcodec_dim{ALL_MODELS[name]}.yaml")
        with open(config) as fp:
            conf = yaml.load(fp, Loader=yaml.FullLoader)
    else:
        with open(config_path) as fp:
            conf = yaml.load(fp, Loader=yaml.FullLoader)["model_params"]

    model = RepCodec(**conf)
    model.load_state_dict(torch.load(model_path, map_location="cpu")["model"]["repcodec"])
    model.quantizer.initial()
    model.eval()
    return model


def load_shard(in_dir: Path, rank: int, n_shard: int) -> Tuple[np.ndarray, List[int]]:
    feat_path = in_dir / f"{rank}_{n_shard}.npy"
    len_path = in_dir / f"{rank}_{n_shard}.len"

    with open(len_path) as fp:
        lengths = [int(line.strip()) for line in fp]

    return np.load(feat_path.as_posix(), mmap_mode="r"), lengths


def pad_data(data: List[np.ndarray]) -> List[np.ndarray]:
    max_len = max([d.shape[0] for d in data])
    data = [
        np.pad(d, [(0, max_len - d.shape[0]), (0, 0)], "constant", constant_values=0.0)
        for d in data
    ]
    return data


def make_batch_data(data: np.ndarray, shard_lengths: List[int], batch_size: int):
    batch_data = []
    batch_lens = []
    offsets = np.cumsum([0] + shard_lengths)
    assert len(data) == offsets[-1], f"{len(data)} {offsets[-1]}"

    # from longest to shortest
    for i in range(len(shard_lengths)):
        if batch_size > len(batch_data):
            batch_data.append(data[offsets[i]: offsets[i + 1]])
            batch_lens.append(shard_lengths[i])
        else:
            yield {
                "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),  # (bsz, seq len, hidden dim)
                "lengths": batch_lens
            }
            batch_data = [data[offsets[i]: offsets[i + 1]]]
            batch_lens = [shard_lengths[i]]
    if len(batch_data) > 0:
        yield {
            "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),
            "lengths": batch_lens
        }


def tokenize_batch(model: RepCodec, batch: dict, device: str) -> List[List[int]]:
    with torch.no_grad():
        data = batch["data"].transpose(1, 2).to(device)  # (bsz, hidden dim, seq len)
        x = model.encoder(data)
        z = model.projector(x)
        _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))

    # when bsz=1: (1, seq len)
    if idx.dim() == 2:
        return idx.cpu().data.numpy().tolist()
    # when bsz>1: (1, bsz, seq len)
    tokens = idx.cpu().data.numpy().tolist()[0]
    res = []
    batch_lens = batch["lengths"]
    for i in range(len(tokens)):
        n_tokens = batch_lens[i]
        res.append(tokens[i][:n_tokens])
    return res


def load_tsv(path: str):
    with open(path) as fp:
        root = fp.readline().strip()
        names = []
        for line in fp:
            names.append(line.strip().split("\t")[0])
    return root, names


def cli():
    args = parse_args()
    device = "cuda" if args.use_gpu else "cpu"

    model = load_model(model_path=args.model, config_path=args.model_config_path)
    model.to(device)

    in_dir = Path(args.in_dir)
    n_shard = args.n_shard
    batch_size = args.batch_size

    root_dir, file_names = load_tsv(args.tsv_path)

    output_dir = args.out_dir
    os.makedirs(output_dir, exist_ok=True)

    processed_cnt = 0
    pbar = tqdm(total=len(file_names))
    with open(os.path.join(output_dir, "tokens"), mode="w+") as fp:
        fp.write(f"{root_dir}\n")

        for rank in range(n_shard):
            shard_data, shard_lengths = load_shard(in_dir, rank, n_shard)
            for batch in make_batch_data(shard_data, shard_lengths, batch_size=batch_size):
                batch_tokens = tokenize_batch(model, batch, device)

                for tokens in batch_tokens:
                    fp.write(f"{file_names[processed_cnt]}\t{' '.join(map(str, tokens))}\n")
                    processed_cnt += 1

                pbar.update(len(batch_tokens))
    assert processed_cnt == len(file_names), "# lines of tsv do not match # of representations!"

    pbar.close()
    print("Tokenize successfully!")


if __name__ == '__main__':
    cli()
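For reference, tokenize.py expects representation shards named {rank}_{n_shard}.npy with matching .len files in in_dir, plus a fairseq-style tsv whose first line is the audio root. A hypothetical invocation (all paths are placeholders) would look like:

python RepCodec/repcodec/tokenize.py /path/to/reprs \
    --model /path/to/hubert_base_l9.pkl \
    --tsv_path /path/to/dev-clean.tsv \
    --n_shard 1 --batch_size 16 --use_gpu --out_dir tokens_out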
RepCodec/setup.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup

try:
    with open("README.md") as fp:
        long_description = fp.read()
except Exception:
    long_description = ""

setup(
    name="RepCodec",
    version="v1.0.0",
    description="A Speech Representation Codec for Speech Tokenization",
    long_description=long_description,
    url="https://github.com/mct10/RepCodec",
    packages=["repcodec", "repcodec.modules", "repcodec.layers"],
    package_data={
        "repcodec": ["configs/*.yaml"]
    },
    install_requires=["numpy", "tqdm", "torch", "tensorboardX", "PyYAML"],
    entry_points={
        'console_scripts': [
            "repcodec=repcodec.tokenize:cli"
        ]
    }
)
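Installing this package (for example with pip install ./RepCodec; path assumed) registers a repcodec console command mapped to repcodec.tokenize:cli, so the tokenizer above can also be run as repcodec <in_dir> --model ... instead of calling the script directly.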
RepCodec/train.py
ADDED
@@ -0,0 +1,228 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import argparse
import logging

import os

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
logger = logging.getLogger("repcodec_train")  # init logger before other modules

import random

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader

from dataloader import ReprDataset, ReprCollater
from losses.repr_reconstruct_loss import ReprReconstructLoss
from repcodec.RepCodec import RepCodec
from trainer.autoencoder import Trainer


class TrainMain:
    def __init__(self, args):
        # Fix seed and make backends deterministic
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
            logger.info(f"device: cpu")
        else:
            self.device = torch.device('cuda:0')  # only supports single gpu for now
            logger.info(f"device: gpu")
            torch.cuda.manual_seed_all(args.seed)
            if args.disable_cudnn == "False":
                torch.backends.cudnn.benchmark = True

        # initialize config
        with open(args.config, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        self.config.update(vars(args))

        # initialize model folder
        expdir = os.path.join(args.exp_root, args.tag)
        os.makedirs(expdir, exist_ok=True)
        self.config["outdir"] = expdir

        # save config
        with open(os.path.join(expdir, "config.yml"), "w") as f:
            yaml.dump(self.config, f, Dumper=yaml.Dumper)
        for key, value in self.config.items():
            logger.info(f"{key} = {value}")

        # initialize attribute
        self.resume: str = args.resume
        self.data_loader = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.trainer = None

        # initialize batch_length
        self.batch_length: int = self.config['batch_length']
        self.data_path: str = self.config['data']['path']

    def initialize_data_loader(self):
        train_set = self._build_dataset("train")
        valid_set = self._build_dataset("valid")
        collater = ReprCollater()

        logger.info(f"The number of training files = {len(train_set)}.")
        logger.info(f"The number of validation files = {len(valid_set)}.")
        dataset = {"train": train_set, "dev": valid_set}
        self._set_data_loader(dataset, collater)

    def define_model_optimizer_scheduler(self):
        # model arch
        self.model = {
            "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
        }
        logger.info(f"Model Arch:\n{self.model['repcodec']}")

        # opt
        optimizer_class = getattr(
            torch.optim,
            self.config["model_optimizer_type"]
        )
        self.optimizer = {
            "repcodec": optimizer_class(
                self.model["repcodec"].parameters(),
                **self.config["model_optimizer_params"]
            )
        }

        # scheduler
        scheduler_class = getattr(
            torch.optim.lr_scheduler,
            self.config.get("model_scheduler_type", "StepLR"),
        )
        self.scheduler = {
            "repcodec": scheduler_class(
                optimizer=self.optimizer["repcodec"],
                **self.config["model_scheduler_params"]
            )
        }

    def define_criterion(self):
        self.criterion = {
            "repr_reconstruct_loss": ReprReconstructLoss(
                **self.config.get("repr_reconstruct_loss_params", {}),
            ).to(self.device)
        }

    def define_trainer(self):
        self.trainer = Trainer(
            steps=0,
            epochs=0,
            data_loader=self.data_loader,
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=self.config,
            device=self.device
        )

    def initialize_model(self):
        initial = self.config.get("initial", "")
        if os.path.exists(self.resume):  # resume from trained model
            self.trainer.load_checkpoint(self.resume)
            logger.info(f"Successfully resumed from {self.resume}.")
        elif os.path.exists(initial):  # initialize new model with the pre-trained model
            self.trainer.load_checkpoint(initial, load_only_params=True)
            logger.info(f"Successfully initialized parameters from {initial}.")
        else:
            logger.info("Train from scratch")

    def run(self):
        assert self.trainer is not None
        self.trainer: Trainer
        try:
            logger.info(f"The current training step: {self.trainer.steps}")
            self.trainer.train_max_steps = self.config["train_max_steps"]
            if not self.trainer._check_train_finish():
                self.trainer.run()
        finally:
            self.trainer.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")

    def _build_dataset(
            self, subset: str
    ) -> ReprDataset:
        data_dir = os.path.join(
            self.data_path, self.config['data']['subset'][subset]
        )
        params = {
            "data_dir": data_dir,
            "batch_len": self.batch_length
        }
        return ReprDataset(**params)

    def _set_data_loader(self, dataset, collater):
        self.data_loader = {
            "train": DataLoader(
                dataset=dataset["train"],
                shuffle=True,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
                pin_memory=self.config["pin_memory"],
            ),
            "dev": DataLoader(
                dataset=dataset["dev"],
                shuffle=False,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=0,
                pin_memory=False,  # save some memory. set to True if you have enough memory.
            ),
        }


def train():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="the path of config yaml file."
    )
    parser.add_argument(
        "--tag", type=str, required=True,
        help="the outputs will be saved to exp_root/tag/"
    )
    parser.add_argument(
        "--exp_root", type=str, default="exp"
    )
    parser.add_argument(
        "--resume", default="", type=str, nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument("--seed", default=1337, type=int)
    parser.add_argument("--disable_cudnn", choices=("True", "False"), default="False", help="Disable CUDNN")
    args = parser.parse_args()

    train_main = TrainMain(args)
    train_main.initialize_data_loader()
    train_main.define_model_optimizer_scheduler()
    train_main.define_criterion()
    train_main.define_trainer()
    train_main.initialize_model()
    train_main.run()


if __name__ == '__main__':
    train()
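A hypothetical launch command matching the argparse options above (the config path refers to the file added next; tag and exp_root are placeholders):

python RepCodec/train.py -c RepCodec/train_configs/ex_dim768_mse.yaml --tag dim768_mse --exp_root exp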
RepCodec/train_configs/ex_dim768_mse.yaml
ADDED
@@ -0,0 +1,74 @@
###########################################################
#                      DATA SETTING                       #
###########################################################
data:
  path: "/dir/to/representations/"
  subset:
    train: "train_set_name"
    valid: "valid_set_name"
    test: "test_set_name"

###########################################################
#                     MODEL SETTING                       #
###########################################################
model_params:
  input_channels: 768
  output_channels: 768
  encode_channels: 768
  decode_channels: 768
  code_dim: 768
  codebook_num: 1
  codebook_size: 1024
  bias: true
  enc_ratios: [1, 1]
  dec_ratios: [1, 1]
  enc_strides: [1, 1]  # no downsampling
  dec_strides: [1, 1]
  enc_kernel_size: 3
  dec_kernel_size: 3
  enc_block_dilations: [1, 1]
  enc_block_kernel_size: 3
  dec_block_dilations: [1, 1]
  dec_block_kernel_size: 3

###########################################################
#                  METRIC LOSS SETTING                    #
###########################################################
repr_reconstruct_loss_params:
  loss_type: l2

###########################################################
#                  LOSS WEIGHT SETTING                    #
###########################################################
lambda_vq_loss: 1.0                 # Loss weight of vector quantize loss.
lambda_repr_reconstruct_loss: 45.0

###########################################################
#                  DATA LOADER SETTING                    #
###########################################################
batch_size: 32      # Batch size.
batch_length: 96    # Length of each audio in batch (training w/o adv).
pin_memory: true    # Whether to pin memory in Pytorch DataLoader.
num_workers: 4      # Number of workers in Pytorch DataLoader.

###########################################################
#             OPTIMIZER & SCHEDULER SETTING               #
###########################################################
model_optimizer_type: Adam
model_optimizer_params:
  lr: 1.0e-4
  betas: [0.5, 0.9]
  weight_decay: 0.0
model_scheduler_type: StepLR
model_scheduler_params:
  step_size: 200000  # Model's scheduler step size.
  gamma: 1.0
grad_norm: -1

###########################################################
#                    INTERVAL SETTING                     #
###########################################################
train_max_steps: 200000     # Number of training steps. (w/o adv)
save_interval_steps: 20000  # Interval steps to save checkpoint.
eval_interval_steps: 2000   # Interval steps to evaluate the network.
log_interval_steps: 100     # Interval steps to record the training log.
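How the two lambda weights enter the objective, paraphrasing trainer/autoencoder.py below (an illustrative, self-contained sketch; loss_type l2 is assumed to behave like a mean-squared error on the representations, and all tensors here are random placeholders):

import torch
import torch.nn.functional as F

lambda_vq_loss, lambda_repr_reconstruct_loss = 1.0, 45.0   # values from this config
x = torch.randn(4, 768, 96)       # a batch of input representations
y_hat = torch.randn(4, 768, 96)   # reconstruction from RepCodec (placeholder)
vq_loss = torch.rand(1)           # VQ loss returned by the quantizer (placeholder)

# one training step, schematically:
codec_loss = lambda_vq_loss * vq_loss.sum() + lambda_repr_reconstruct_loss * F.mse_loss(y_hat, x)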
RepCodec/trainer/autoencoder.py
ADDED
@@ -0,0 +1,287 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import logging
import os
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm

logger = logging.getLogger("repcodec_train")


class Trainer:
    def __init__(
            self,
            steps: int,
            epochs: int,
            data_loader: dict,
            model: dict,
            criterion: dict,
            optimizer: dict,
            scheduler: dict,
            config: dict,
            device=torch.device("cpu"),
    ):
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.train_max_steps = config.get("train_max_steps", 0)

    def _train_step(self, batch):
        """Single step of training."""
        mode = "train"
        x = batch
        x = x.to(self.device)

        codec_loss = 0.0
        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)
        self._update_repcodec(codec_loss)

        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    @torch.no_grad()
    def _eval_step(self, batch):
        """Single step of evaluation."""
        mode = "eval"
        x = batch
        x = x.to(self.device)

        codec_loss = 0.0
        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)

    def run(self):
        """Run training."""
        self.finish_train = False
        self.tqdm = tqdm(
            initial=self.steps, total=self.train_max_steps, desc="[train]"
        )
        while True:
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logger.info("Finished training.")

    def save_checkpoint(self, checkpoint_path: str):
        state_dict = {
            "model": {
                "repcodec": self.model["repcodec"].state_dict()
            },
            "optimizer": {
                "repcodec": self.optimizer["repcodec"].state_dict(),
            },
            "scheduler": {
                "repcodec": self.scheduler["repcodec"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }

        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
            self,
            checkpoint_path: str,
            strict: bool = True,
            load_only_params: bool = False
    ):
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model["repcodec"].load_state_dict(
            state_dict["model"]["repcodec"], strict=strict
        )

        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["repcodec"].load_state_dict(
                state_dict["optimizer"]["repcodec"]
            )
            self.scheduler["repcodec"].load_state_dict(
                state_dict["scheduler"]["repcodec"]
            )

    def _train_epoch(self):
        """One epoch of training."""
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        if train_steps_per_epoch > 200:
            logger.info(
                f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
                f"({self.train_steps_per_epoch} steps per epoch)."
            )

    def _eval_epoch(self):
        """One epoch of evaluation."""
        logger.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for key in self.model.keys():
            self.model[key].eval()

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
                tqdm(self.data_loader["dev"], desc="[eval]"), 1
        ):
            # eval one step
            self._eval_step(batch)

        logger.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logger.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        # restore mode
        for key in self.model.keys():
            self.model[key].train()

    def _metric_loss(self, predict_y, natural_y, mode='train'):
        """Metric losses."""
        metric_loss = 0.0

        repr_reconstruct_loss = self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
        repr_reconstruct_loss *= self.config["lambda_repr_reconstruct_loss"]
        self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
        metric_loss += repr_reconstruct_loss

        return metric_loss

    def _update_repcodec(self, repr_loss):
        """Update generator."""
        self.optimizer["repcodec"].zero_grad()
        repr_loss.backward()
        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["repcodec"].parameters(),
                self.config["grad_norm"],
            )
        self.optimizer["repcodec"].step()
        self.scheduler["repcodec"].step()

    def _record_loss(self, name: str, loss, mode='train'):
        """Record loss."""
        if torch.is_tensor(loss):
            loss = loss.item()

        if mode == 'train':
            self.total_train_loss[f"train/{name}"] += loss
        elif mode == 'eval':
            self.total_eval_loss[f"eval/{name}"] += loss
        else:
            raise NotImplementedError(f"Mode ({mode}) is not supported!")

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.train_max_steps:
            self.finish_train = True
        else:
            self.finish_train = False
        return self.finish_train

    def _perplexity(self, perplexity, label=None, mode='train'):
        if label:
            name = f"{mode}/ppl_{label}"
        else:
            name = f"{mode}/ppl"
        if torch.numel(perplexity) > 1:
            perplexity = perplexity.tolist()
            for idx, ppl in enumerate(perplexity):
                self._record_loss(f"{name}_{idx}", ppl, mode=mode)
        else:
            self._record_loss(name, perplexity, mode=mode)

    def _vq_loss(self, vqloss, label=None, mode='train'):
        if label:
            name = f"{mode}/vqloss_{label}"
        else:
            name = f"{mode}/vqloss"
        vqloss = torch.sum(vqloss)
        vqloss *= self.config["lambda_vq_loss"]
        self._record_loss(name, vqloss, mode=mode)

        return vqloss
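A small sketch for inspecting a checkpoint written by save_checkpoint above (the path is a placeholder); the nested model/repcodec entry is exactly what repcodec/tokenize.py's load_model reads back:

import torch

ckpt = torch.load("exp/dim768_mse/checkpoint-200000steps.pkl", map_location="cpu")
print(sorted(ckpt.keys()))             # ['epochs', 'model', 'optimizer', 'scheduler', 'steps']
weights = ckpt["model"]["repcodec"]    # state dict consumed by tokenize.load_model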
__pycache__/post_process_audio.cpython-310.pyc
ADDED
Binary file (3.1 kB).
__pycache__/vocoder.cpython-310.pyc
ADDED
Binary file (5.75 kB).
decoders/config.yaml
ADDED
@@ -0,0 +1,15 @@
backbone:
  class_path: vocos.models.VocosBackbone
  init_args:
    input_channels: 1024
    dim: 512
    intermediate_dim: 1536
    num_layers: 8

head:
  class_path: vocos.heads.ISTFTHead
  init_args:
    dim: 512
    n_fft: 3528
    hop_length: 882
    padding: same
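A minimal sketch of instantiating this decoder config, assuming the vocos package is installed; the loader relies only on the class_path/init_args layout above and does not assume any other Vocos API:

import importlib
import yaml

def build(node):
    # Resolve "package.module.ClassName" and construct it with the YAML init_args.
    module_name, cls_name = node["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get("init_args", {}))

with open("decoders/config.yaml") as f:
    cfg = yaml.safe_load(f)

backbone = build(cfg["backbone"])   # vocos.models.VocosBackbone
head = build(cfg["head"])           # vocos.heads.ISTFTHead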
decoders/decoder_131000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b99f0be84eeef3a32f29cd55beb89727fd0b2fd0df3dbad3023508f4c7185c37
size 72611958
decoders/decoder_151000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8af97a29d3483f9d4a3755992837501bd7d6caa1a69382ed16e64039e0ea0998
size 72610550
descriptaudiocodec/dac/__init__.py
ADDED
@@ -0,0 +1,16 @@
__version__ = "1.0.0"

# preserved here for legacy reasons
__model_version__ = "latest"

import audiotools

audiotools.ml.BaseModel.INTERN += ["dac.**"]
audiotools.ml.BaseModel.EXTERN += ["einops"]


from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
descriptaudiocodec/dac/__main__.py
ADDED
@@ -0,0 +1,36 @@
import sys

import argbind

from dac.utils import download
from dac.utils.decode import decode
from dac.utils.encode import encode

STAGES = ["encode", "decode", "download"]


def run(stage: str):
    """Run stages.

    Parameters
    ----------
    stage : str
        Stage to run
    """
    if stage not in STAGES:
        raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
    stage_fn = globals()[stage]

    if stage == "download":
        stage_fn()
        return

    stage_fn()


if __name__ == "__main__":
    group = sys.argv.pop(1)
    args = argbind.parse_args(group=group)

    with argbind.scope(args):
        run(group)
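This __main__ entry dispatches on the first positional argument, so the bundled Descript audio codec can be driven as python -m dac download, python -m dac encode ... or python -m dac decode ...; the remaining flags come from the argbind bindings in dac.utils and are not repeated here.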
descriptaudiocodec/dac/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (516 Bytes).
descriptaudiocodec/dac/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (488 Bytes).
descriptaudiocodec/dac/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (489 Bytes).
descriptaudiocodec/dac/compare/__init__.py
ADDED
File without changes
descriptaudiocodec/dac/compare/encodec.py
ADDED
@@ -0,0 +1,54 @@
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from encodec import EncodecModel


class Encodec(BaseModel):
    def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
        super().__init__()

        if sample_rate == 24000:
            self.model = EncodecModel.encodec_model_24khz()
        else:
            self.model = EncodecModel.encodec_model_48khz()
        self.model.set_target_bandwidth(bandwidth)
        self.sample_rate = 44100

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = 44100,
        n_quantizers: int = None,
    ):
        signal = AudioSignal(audio_data, sample_rate)
        signal.resample(self.model.sample_rate)
        recons = self.model(signal.audio_data)
        recons = AudioSignal(recons, self.model.sample_rate)
        recons.resample(sample_rate)
        return {"audio": recons.audio_data}


if __name__ == "__main__":
    import numpy as np
    from functools import partial

    model = Encodec()

    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Make a forward pass
    out = model(x)["audio"]

    print(x.shape, out.shape)
descriptaudiocodec/dac/model/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .base import CodecMixin
from .base import DACFile
from .dac import DAC
from .discriminator import Discriminator
descriptaudiocodec/dac/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (357 Bytes).
descriptaudiocodec/dac/model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (330 Bytes).
descriptaudiocodec/dac/model/__pycache__/base.cpython-310.pyc
ADDED
Binary file (7.26 kB).
descriptaudiocodec/dac/model/__pycache__/base.cpython-39.pyc
ADDED
Binary file (7.18 kB).
descriptaudiocodec/dac/model/__pycache__/dac.cpython-310.pyc
ADDED
Binary file (10.8 kB).