Spaces:
Running
on
A10G
Running
on
A10G
QOL UI improvements
#1
by
multimodalart
HF staff
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .github/actions/audiocraft_build/action.yml +0 -2
- .github/workflows/audiocraft_docs.yml +3 -3
- .github/workflows/audiocraft_tests.yml +1 -6
- .gitignore +1 -8
- CHANGELOG.md +2 -46
- CONTRIBUTING.md +2 -2
- LICENSE_weights +157 -399
- MANIFEST.in +0 -7
- model_cards/MUSICGEN_MODEL_CARD.md → MODEL_CARD.md +8 -32
- Makefile +4 -23
- README.md +64 -48
- app.py +136 -0
- app_batched.py +128 -0
- assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
- assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
- audiocraft/__init__.py +1 -17
- audiocraft/adversarial/__init__.py +0 -22
- audiocraft/adversarial/discriminators/__init__.py +0 -10
- audiocraft/adversarial/discriminators/base.py +0 -34
- audiocraft/adversarial/discriminators/mpd.py +0 -106
- audiocraft/adversarial/discriminators/msd.py +0 -126
- audiocraft/adversarial/discriminators/msstftd.py +0 -134
- audiocraft/adversarial/losses.py +0 -228
- audiocraft/data/__init__.py +1 -3
- audiocraft/data/audio.py +21 -39
- audiocraft/data/audio_dataset.py +31 -93
- audiocraft/data/audio_utils.py +14 -21
- audiocraft/data/info_audio_dataset.py +0 -110
- audiocraft/data/music_dataset.py +0 -270
- audiocraft/data/sound_dataset.py +0 -330
- audiocraft/data/zip.py +6 -8
- audiocraft/environment.py +0 -176
- audiocraft/grids/__init__.py +0 -6
- audiocraft/grids/_base_explorers.py +0 -80
- audiocraft/grids/audiogen/__init__.py +0 -6
- audiocraft/grids/audiogen/audiogen_base_16khz.py +0 -23
- audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +0 -68
- audiocraft/grids/compression/__init__.py +0 -6
- audiocraft/grids/compression/_explorers.py +0 -55
- audiocraft/grids/compression/debug.py +0 -31
- audiocraft/grids/compression/encodec_audiogen_16khz.py +0 -29
- audiocraft/grids/compression/encodec_base_24khz.py +0 -28
- audiocraft/grids/compression/encodec_musicgen_32khz.py +0 -34
- audiocraft/grids/diffusion/4_bands_base_32khz.py +0 -27
- audiocraft/grids/diffusion/__init__.py +0 -6
- audiocraft/grids/diffusion/_explorers.py +0 -66
- audiocraft/grids/musicgen/__init__.py +0 -6
- audiocraft/grids/musicgen/_explorers.py +0 -93
- audiocraft/grids/musicgen/musicgen_base_32khz.py +0 -43
- audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +0 -67
.github/actions/audiocraft_build/action.yml
CHANGED
@@ -21,8 +21,6 @@ runs:
|
|
21 |
python3 -m venv env
|
22 |
. env/bin/activate
|
23 |
python -m pip install --upgrade pip
|
24 |
-
pip install torch torchvision torchaudio
|
25 |
-
pip install xformers
|
26 |
pip install -e '.[dev]'
|
27 |
- name: System Dependencies
|
28 |
shell: bash
|
|
|
21 |
python3 -m venv env
|
22 |
. env/bin/activate
|
23 |
python -m pip install --upgrade pip
|
|
|
|
|
24 |
pip install -e '.[dev]'
|
25 |
- name: System Dependencies
|
26 |
shell: bash
|
.github/workflows/audiocraft_docs.yml
CHANGED
@@ -23,9 +23,9 @@ jobs:
|
|
23 |
- name: Make docs
|
24 |
run: |
|
25 |
. env/bin/activate
|
26 |
-
make
|
27 |
-
git add -f
|
28 |
-
git commit -m
|
29 |
|
30 |
- name: Push branch
|
31 |
run: |
|
|
|
23 |
- name: Make docs
|
24 |
run: |
|
25 |
. env/bin/activate
|
26 |
+
make docs
|
27 |
+
git add -f docs
|
28 |
+
git commit -m docs
|
29 |
|
30 |
- name: Push branch
|
31 |
run: |
|
.github/workflows/audiocraft_tests.yml
CHANGED
@@ -12,11 +12,6 @@ jobs:
|
|
12 |
steps:
|
13 |
- uses: actions/checkout@v2
|
14 |
- uses: ./.github/actions/audiocraft_build
|
15 |
-
-
|
16 |
-
run: |
|
17 |
. env/bin/activate
|
18 |
make tests
|
19 |
-
- name: Run integration tests
|
20 |
-
run: |
|
21 |
-
. env/bin/activate
|
22 |
-
make tests_integ
|
|
|
12 |
steps:
|
13 |
- uses: actions/checkout@v2
|
14 |
- uses: ./.github/actions/audiocraft_build
|
15 |
+
- run: |
|
|
|
16 |
. env/bin/activate
|
17 |
make tests
|
|
|
|
|
|
|
|
.gitignore
CHANGED
@@ -35,7 +35,7 @@ wheels/
|
|
35 |
.coverage
|
36 |
|
37 |
# docs
|
38 |
-
/
|
39 |
|
40 |
# dotenv
|
41 |
.env
|
@@ -46,13 +46,6 @@ wheels/
|
|
46 |
venv/
|
47 |
ENV/
|
48 |
|
49 |
-
# egs with manifest files
|
50 |
-
egs/*
|
51 |
-
!egs/example
|
52 |
-
# local datasets
|
53 |
-
dataset/*
|
54 |
-
!dataset/example
|
55 |
-
|
56 |
# personal notebooks & scripts
|
57 |
*/local_scripts
|
58 |
*/notes
|
|
|
35 |
.coverage
|
36 |
|
37 |
# docs
|
38 |
+
/docs
|
39 |
|
40 |
# dotenv
|
41 |
.env
|
|
|
46 |
venv/
|
47 |
ENV/
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
# personal notebooks & scripts
|
50 |
*/local_scripts
|
51 |
*/notes
|
CHANGELOG.md
CHANGED
@@ -4,50 +4,6 @@ All notable changes to this project will be documented in this file.
|
|
4 |
|
5 |
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
6 |
|
7 |
-
## [
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
## [1.1.0] - 2023-11-06
|
13 |
-
|
14 |
-
Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
|
15 |
-
|
16 |
-
Fixed DAC support with non default number of codebooks.
|
17 |
-
|
18 |
-
Fixed bug when `two_step_cfg` was overriden when calling `generate()`.
|
19 |
-
|
20 |
-
Fixed samples being always prompted with audio, rather than having both prompted and unprompted.
|
21 |
-
|
22 |
-
**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release.
|
23 |
-
The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners.
|
24 |
-
We removed it, so you might need to retrain models.
|
25 |
-
|
26 |
-
**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before).
|
27 |
-
|
28 |
-
**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
|
29 |
-
retrained a model with this pattern, so hopefully this won't impact you!
|
30 |
-
|
31 |
-
|
32 |
-
## [1.0.0] - 2023-09-07
|
33 |
-
|
34 |
-
Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
|
35 |
-
Added pretrained model for AudioGen and MultiBandDiffusion.
|
36 |
-
|
37 |
-
## [0.0.2] - 2023-08-01
|
38 |
-
|
39 |
-
Improved demo, fixed top p (thanks @jnordberg).
|
40 |
-
|
41 |
-
Compressor tanh on output to avoid clipping with some style (especially piano).
|
42 |
-
Now repeating the conditioning periodically if it is too short.
|
43 |
-
|
44 |
-
More options when launching Gradio app locally (thanks @ashleykleynhans).
|
45 |
-
|
46 |
-
Testing out PyTorch 2.0 memory efficient attention.
|
47 |
-
|
48 |
-
Added extended generation (infinite length) by slowly moving the windows.
|
49 |
-
Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
|
50 |
-
|
51 |
-
## [0.0.1] - 2023-06-09
|
52 |
-
|
53 |
-
Initial release, with model evaluation only.
|
|
|
4 |
|
5 |
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
6 |
|
7 |
+
## [0.0.1a] - TBD
|
8 |
|
9 |
+
Initial release, with model evaluation only.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CONTRIBUTING.md
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
-
# Contributing to
|
2 |
|
3 |
We want to make contributing to this project as easy and transparent as
|
4 |
possible.
|
5 |
|
6 |
## Pull Requests
|
7 |
|
8 |
-
|
9 |
Therefore, we do not plan on accepting many pull requests for new features.
|
10 |
We certainly welcome them for bug fixes.
|
11 |
|
|
|
1 |
+
# Contributing to Audiocraft
|
2 |
|
3 |
We want to make contributing to this project as easy and transparent as
|
4 |
possible.
|
5 |
|
6 |
## Pull Requests
|
7 |
|
8 |
+
Audiocraft is the implementation of a research paper.
|
9 |
Therefore, we do not plan on accepting many pull requests for new features.
|
10 |
We certainly welcome them for bug fixes.
|
11 |
|
LICENSE_weights
CHANGED
@@ -1,399 +1,157 @@
|
|
1 |
-
Attribution-NonCommercial 4.0 International
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
Creative Commons public licenses
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
terms and conditions,
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
its terms and conditions.
|
159 |
-
|
160 |
-
3. Term. The term of this Public License is specified in Section
|
161 |
-
6(a).
|
162 |
-
|
163 |
-
4. Media and formats; technical modifications allowed. The
|
164 |
-
Licensor authorizes You to exercise the Licensed Rights in
|
165 |
-
all media and formats whether now known or hereafter created,
|
166 |
-
and to make technical modifications necessary to do so. The
|
167 |
-
Licensor waives and/or agrees not to assert any right or
|
168 |
-
authority to forbid You from making technical modifications
|
169 |
-
necessary to exercise the Licensed Rights, including
|
170 |
-
technical modifications necessary to circumvent Effective
|
171 |
-
Technological Measures. For purposes of this Public License,
|
172 |
-
simply making modifications authorized by this Section 2(a)
|
173 |
-
(4) never produces Adapted Material.
|
174 |
-
|
175 |
-
5. Downstream recipients.
|
176 |
-
|
177 |
-
a. Offer from the Licensor -- Licensed Material. Every
|
178 |
-
recipient of the Licensed Material automatically
|
179 |
-
receives an offer from the Licensor to exercise the
|
180 |
-
Licensed Rights under the terms and conditions of this
|
181 |
-
Public License.
|
182 |
-
|
183 |
-
b. No downstream restrictions. You may not offer or impose
|
184 |
-
any additional or different terms or conditions on, or
|
185 |
-
apply any Effective Technological Measures to, the
|
186 |
-
Licensed Material if doing so restricts exercise of the
|
187 |
-
Licensed Rights by any recipient of the Licensed
|
188 |
-
Material.
|
189 |
-
|
190 |
-
6. No endorsement. Nothing in this Public License constitutes or
|
191 |
-
may be construed as permission to assert or imply that You
|
192 |
-
are, or that Your use of the Licensed Material is, connected
|
193 |
-
with, or sponsored, endorsed, or granted official status by,
|
194 |
-
the Licensor or others designated to receive attribution as
|
195 |
-
provided in Section 3(a)(1)(A)(i).
|
196 |
-
|
197 |
-
b. Other rights.
|
198 |
-
|
199 |
-
1. Moral rights, such as the right of integrity, are not
|
200 |
-
licensed under this Public License, nor are publicity,
|
201 |
-
privacy, and/or other similar personality rights; however, to
|
202 |
-
the extent possible, the Licensor waives and/or agrees not to
|
203 |
-
assert any such rights held by the Licensor to the limited
|
204 |
-
extent necessary to allow You to exercise the Licensed
|
205 |
-
Rights, but not otherwise.
|
206 |
-
|
207 |
-
2. Patent and trademark rights are not licensed under this
|
208 |
-
Public License.
|
209 |
-
|
210 |
-
3. To the extent possible, the Licensor waives any right to
|
211 |
-
collect royalties from You for the exercise of the Licensed
|
212 |
-
Rights, whether directly or through a collecting society
|
213 |
-
under any voluntary or waivable statutory or compulsory
|
214 |
-
licensing scheme. In all other cases the Licensor expressly
|
215 |
-
reserves any right to collect such royalties, including when
|
216 |
-
the Licensed Material is used other than for NonCommercial
|
217 |
-
purposes.
|
218 |
-
|
219 |
-
Section 3 -- License Conditions.
|
220 |
-
|
221 |
-
Your exercise of the Licensed Rights is expressly made subject to the
|
222 |
-
following conditions.
|
223 |
-
|
224 |
-
a. Attribution.
|
225 |
-
|
226 |
-
1. If You Share the Licensed Material (including in modified
|
227 |
-
form), You must:
|
228 |
-
|
229 |
-
a. retain the following if it is supplied by the Licensor
|
230 |
-
with the Licensed Material:
|
231 |
-
|
232 |
-
i. identification of the creator(s) of the Licensed
|
233 |
-
Material and any others designated to receive
|
234 |
-
attribution, in any reasonable manner requested by
|
235 |
-
the Licensor (including by pseudonym if
|
236 |
-
designated);
|
237 |
-
|
238 |
-
ii. a copyright notice;
|
239 |
-
|
240 |
-
iii. a notice that refers to this Public License;
|
241 |
-
|
242 |
-
iv. a notice that refers to the disclaimer of
|
243 |
-
warranties;
|
244 |
-
|
245 |
-
v. a URI or hyperlink to the Licensed Material to the
|
246 |
-
extent reasonably practicable;
|
247 |
-
|
248 |
-
b. indicate if You modified the Licensed Material and
|
249 |
-
retain an indication of any previous modifications; and
|
250 |
-
|
251 |
-
c. indicate the Licensed Material is licensed under this
|
252 |
-
Public License, and include the text of, or the URI or
|
253 |
-
hyperlink to, this Public License.
|
254 |
-
|
255 |
-
2. You may satisfy the conditions in Section 3(a)(1) in any
|
256 |
-
reasonable manner based on the medium, means, and context in
|
257 |
-
which You Share the Licensed Material. For example, it may be
|
258 |
-
reasonable to satisfy the conditions by providing a URI or
|
259 |
-
hyperlink to a resource that includes the required
|
260 |
-
information.
|
261 |
-
|
262 |
-
3. If requested by the Licensor, You must remove any of the
|
263 |
-
information required by Section 3(a)(1)(A) to the extent
|
264 |
-
reasonably practicable.
|
265 |
-
|
266 |
-
4. If You Share Adapted Material You produce, the Adapter's
|
267 |
-
License You apply must not prevent recipients of the Adapted
|
268 |
-
Material from complying with this Public License.
|
269 |
-
|
270 |
-
Section 4 -- Sui Generis Database Rights.
|
271 |
-
|
272 |
-
Where the Licensed Rights include Sui Generis Database Rights that
|
273 |
-
apply to Your use of the Licensed Material:
|
274 |
-
|
275 |
-
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
276 |
-
to extract, reuse, reproduce, and Share all or a substantial
|
277 |
-
portion of the contents of the database for NonCommercial purposes
|
278 |
-
only;
|
279 |
-
|
280 |
-
b. if You include all or a substantial portion of the database
|
281 |
-
contents in a database in which You have Sui Generis Database
|
282 |
-
Rights, then the database in which You have Sui Generis Database
|
283 |
-
Rights (but not its individual contents) is Adapted Material; and
|
284 |
-
|
285 |
-
c. You must comply with the conditions in Section 3(a) if You Share
|
286 |
-
all or a substantial portion of the contents of the database.
|
287 |
-
|
288 |
-
For the avoidance of doubt, this Section 4 supplements and does not
|
289 |
-
replace Your obligations under this Public License where the Licensed
|
290 |
-
Rights include other Copyright and Similar Rights.
|
291 |
-
|
292 |
-
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
293 |
-
|
294 |
-
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
295 |
-
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
296 |
-
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
297 |
-
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
298 |
-
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
299 |
-
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
300 |
-
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
301 |
-
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
302 |
-
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
303 |
-
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
304 |
-
|
305 |
-
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
306 |
-
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
307 |
-
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
308 |
-
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
309 |
-
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
310 |
-
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
311 |
-
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
312 |
-
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
313 |
-
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
314 |
-
|
315 |
-
c. The disclaimer of warranties and limitation of liability provided
|
316 |
-
above shall be interpreted in a manner that, to the extent
|
317 |
-
possible, most closely approximates an absolute disclaimer and
|
318 |
-
waiver of all liability.
|
319 |
-
|
320 |
-
Section 6 -- Term and Termination.
|
321 |
-
|
322 |
-
a. This Public License applies for the term of the Copyright and
|
323 |
-
Similar Rights licensed here. However, if You fail to comply with
|
324 |
-
this Public License, then Your rights under this Public License
|
325 |
-
terminate automatically.
|
326 |
-
|
327 |
-
b. Where Your right to use the Licensed Material has terminated under
|
328 |
-
Section 6(a), it reinstates:
|
329 |
-
|
330 |
-
1. automatically as of the date the violation is cured, provided
|
331 |
-
it is cured within 30 days of Your discovery of the
|
332 |
-
violation; or
|
333 |
-
|
334 |
-
2. upon express reinstatement by the Licensor.
|
335 |
-
|
336 |
-
For the avoidance of doubt, this Section 6(b) does not affect any
|
337 |
-
right the Licensor may have to seek remedies for Your violations
|
338 |
-
of this Public License.
|
339 |
-
|
340 |
-
c. For the avoidance of doubt, the Licensor may also offer the
|
341 |
-
Licensed Material under separate terms or conditions or stop
|
342 |
-
distributing the Licensed Material at any time; however, doing so
|
343 |
-
will not terminate this Public License.
|
344 |
-
|
345 |
-
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
346 |
-
License.
|
347 |
-
|
348 |
-
Section 7 -- Other Terms and Conditions.
|
349 |
-
|
350 |
-
a. The Licensor shall not be bound by any additional or different
|
351 |
-
terms or conditions communicated by You unless expressly agreed.
|
352 |
-
|
353 |
-
b. Any arrangements, understandings, or agreements regarding the
|
354 |
-
Licensed Material not stated herein are separate from and
|
355 |
-
independent of the terms and conditions of this Public License.
|
356 |
-
|
357 |
-
Section 8 -- Interpretation.
|
358 |
-
|
359 |
-
a. For the avoidance of doubt, this Public License does not, and
|
360 |
-
shall not be interpreted to, reduce, limit, restrict, or impose
|
361 |
-
conditions on any use of the Licensed Material that could lawfully
|
362 |
-
be made without permission under this Public License.
|
363 |
-
|
364 |
-
b. To the extent possible, if any provision of this Public License is
|
365 |
-
deemed unenforceable, it shall be automatically reformed to the
|
366 |
-
minimum extent necessary to make it enforceable. If the provision
|
367 |
-
cannot be reformed, it shall be severed from this Public License
|
368 |
-
without affecting the enforceability of the remaining terms and
|
369 |
-
conditions.
|
370 |
-
|
371 |
-
c. No term or condition of this Public License will be waived and no
|
372 |
-
failure to comply consented to unless expressly agreed to by the
|
373 |
-
Licensor.
|
374 |
-
|
375 |
-
d. Nothing in this Public License constitutes or may be interpreted
|
376 |
-
as a limitation upon, or waiver of, any privileges and immunities
|
377 |
-
that apply to the Licensor or You, including from the legal
|
378 |
-
processes of any jurisdiction or authority.
|
379 |
-
|
380 |
-
=======================================================================
|
381 |
-
|
382 |
-
Creative Commons is not a party to its public
|
383 |
-
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
384 |
-
its public licenses to material it publishes and in those instances
|
385 |
-
will be considered the “Licensor.” The text of the Creative Commons
|
386 |
-
public licenses is dedicated to the public domain under the CC0 Public
|
387 |
-
Domain Dedication. Except for the limited purpose of indicating that
|
388 |
-
material is shared under a Creative Commons public license or as
|
389 |
-
otherwise permitted by the Creative Commons policies published at
|
390 |
-
creativecommons.org/policies, Creative Commons does not authorize the
|
391 |
-
use of the trademark "Creative Commons" or any other trademark or logo
|
392 |
-
of Creative Commons without its prior written consent including,
|
393 |
-
without limitation, in connection with any unauthorized modifications
|
394 |
-
to any of its public licenses or any other arrangements,
|
395 |
-
understandings, or agreements concerning use of licensed material. For
|
396 |
-
the avoidance of doubt, this paragraph does not form part of the
|
397 |
-
public licenses.
|
398 |
-
|
399 |
-
Creative Commons may be contacted at creativecommons.org.
|
|
|
1 |
+
# Attribution-NonCommercial-NoDerivatives 4.0 International
|
2 |
+
|
3 |
+
> *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
|
4 |
+
>
|
5 |
+
> ### Using Creative Commons Public Licenses
|
6 |
+
>
|
7 |
+
> Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
8 |
+
>
|
9 |
+
> * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
|
10 |
+
>
|
11 |
+
> * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
|
12 |
+
|
13 |
+
## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License
|
14 |
+
|
15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
16 |
+
|
17 |
+
### Section 1 – Definitions.
|
18 |
+
|
19 |
+
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
20 |
+
|
21 |
+
b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
22 |
+
|
23 |
+
e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
24 |
+
|
25 |
+
f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
26 |
+
|
27 |
+
h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
28 |
+
|
29 |
+
i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
30 |
+
|
31 |
+
h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
|
32 |
+
|
33 |
+
i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
34 |
+
|
35 |
+
j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
36 |
+
|
37 |
+
k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
38 |
+
|
39 |
+
l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
40 |
+
|
41 |
+
### Section 2 – Scope.
|
42 |
+
|
43 |
+
a. ___License grant.___
|
44 |
+
|
45 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
46 |
+
|
47 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
48 |
+
|
49 |
+
B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only.
|
50 |
+
|
51 |
+
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
52 |
+
|
53 |
+
3. __Term.__ The term of this Public License is specified in Section 6(a).
|
54 |
+
|
55 |
+
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
56 |
+
|
57 |
+
5. __Downstream recipients.__
|
58 |
+
|
59 |
+
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
60 |
+
|
61 |
+
B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
62 |
+
|
63 |
+
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
64 |
+
|
65 |
+
b. ___Other rights.___
|
66 |
+
|
67 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
68 |
+
|
69 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
70 |
+
|
71 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
72 |
+
|
73 |
+
### Section 3 – License Conditions.
|
74 |
+
|
75 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
76 |
+
|
77 |
+
a. ___Attribution.___
|
78 |
+
|
79 |
+
1. If You Share the Licensed Material, You must:
|
80 |
+
|
81 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
82 |
+
|
83 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
84 |
+
|
85 |
+
ii. a copyright notice;
|
86 |
+
|
87 |
+
iii. a notice that refers to this Public License;
|
88 |
+
|
89 |
+
iv. a notice that refers to the disclaimer of warranties;
|
90 |
+
|
91 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
92 |
+
|
93 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
94 |
+
|
95 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
96 |
+
|
97 |
+
For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material.
|
98 |
+
|
99 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
100 |
+
|
101 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
102 |
+
|
103 |
+
### Section 4 – Sui Generis Database Rights.
|
104 |
+
|
105 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
106 |
+
|
107 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material;
|
108 |
+
|
109 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
|
110 |
+
|
111 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
112 |
+
|
113 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
114 |
+
|
115 |
+
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
116 |
+
|
117 |
+
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
|
118 |
+
|
119 |
+
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
|
120 |
+
|
121 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
122 |
+
|
123 |
+
### Section 6 – Term and Termination.
|
124 |
+
|
125 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
126 |
+
|
127 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
128 |
+
|
129 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
130 |
+
|
131 |
+
2. upon express reinstatement by the Licensor.
|
132 |
+
|
133 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
134 |
+
|
135 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
136 |
+
|
137 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
138 |
+
|
139 |
+
### Section 7 – Other Terms and Conditions.
|
140 |
+
|
141 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
142 |
+
|
143 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
144 |
+
|
145 |
+
### Section 8 – Interpretation.
|
146 |
+
|
147 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
148 |
+
|
149 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
150 |
+
|
151 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
152 |
+
|
153 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
154 |
+
|
155 |
+
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
156 |
+
>
|
157 |
+
> Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MANIFEST.in
CHANGED
@@ -6,10 +6,3 @@ include *.ini
|
|
6 |
include requirements.txt
|
7 |
include audiocraft/py.typed
|
8 |
include assets/*.mp3
|
9 |
-
include datasets/*.mp3
|
10 |
-
recursive-include config *.yaml
|
11 |
-
recursive-include demos *.py
|
12 |
-
recursive-include demos *.ipynb
|
13 |
-
recursive-include scripts *.py
|
14 |
-
recursive-include model_cards *.md
|
15 |
-
recursive-include docs *.md
|
|
|
6 |
include requirements.txt
|
7 |
include audiocraft/py.typed
|
8 |
include assets/*.mp3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_cards/MUSICGEN_MODEL_CARD.md → MODEL_CARD.md
RENAMED
@@ -12,11 +12,11 @@
|
|
12 |
|
13 |
**Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
|
14 |
|
15 |
-
**Citation details
|
16 |
|
17 |
-
**License
|
18 |
|
19 |
-
**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [
|
20 |
|
21 |
## Intended use
|
22 |
**Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
|
@@ -26,7 +26,7 @@
|
|
26 |
|
27 |
**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models.
|
28 |
|
29 |
-
**Out-of-scope use cases
|
30 |
|
31 |
## Metrics
|
32 |
|
@@ -52,26 +52,17 @@ The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/data
|
|
52 |
|
53 |
## Training datasets
|
54 |
|
55 |
-
The model was trained
|
56 |
|
57 |
-
##
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity |
|
62 |
-
|---|---|---|---|---|
|
63 |
-
| facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - |
|
64 |
-
| facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - |
|
65 |
-
| facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - |
|
66 |
-
| facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 |
|
67 |
-
|
68 |
-
More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section.
|
69 |
|
70 |
## Limitations and biases
|
71 |
|
72 |
**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
|
73 |
|
74 |
-
**Mitigations:**
|
75 |
|
76 |
**Limitations:**
|
77 |
|
@@ -87,19 +78,4 @@ More information can be found in the paper [Simple and Controllable Music Genera
|
|
87 |
|
88 |
**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
|
89 |
|
90 |
-
## Update: stereo models and large melody.
|
91 |
-
|
92 |
-
We further release a set of stereophonic capable models. Those were fine tuned for 200k updates starting
|
93 |
-
from the mono models. The training data is otherwise identical and capabilities and limitations are shared with the base modes. The stereo models work by getting 2 streams of tokens from the EnCodec model, and interleaving those using
|
94 |
-
the delay pattern. We also release a mono large model with melody conditioning capabilities. The list of new models
|
95 |
-
is as follow:
|
96 |
-
|
97 |
-
- facebook/musicgen-stereo-small
|
98 |
-
- facebook/musicgen-stereo-medium
|
99 |
-
- facebook/musicgen-stereo-large
|
100 |
-
- facebook/musicgen-stereo-melody
|
101 |
-
- facebook/musicgen-melody-large
|
102 |
-
- facebook/musicgen-stereo-melody-large
|
103 |
-
|
104 |
-
|
105 |
[arxiv]: https://arxiv.org/abs/2306.05284
|
|
|
12 |
|
13 |
**Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
|
14 |
|
15 |
+
**Citation details** See [our paper][arxiv]
|
16 |
|
17 |
+
**License** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
|
18 |
|
19 |
+
**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [Github repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
|
20 |
|
21 |
## Intended use
|
22 |
**Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
|
|
|
26 |
|
27 |
**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models.
|
28 |
|
29 |
+
**Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
30 |
|
31 |
## Metrics
|
32 |
|
|
|
52 |
|
53 |
## Training datasets
|
54 |
|
55 |
+
The model was trained using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
|
56 |
|
57 |
+
## Quantitative analysis
|
58 |
|
59 |
+
More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Experimental Setup section.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
## Limitations and biases
|
62 |
|
63 |
**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
|
64 |
|
65 |
+
**Mitigations:** All vocals have been removed from the data source using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). The model is therefore not able to produce vocals.
|
66 |
|
67 |
**Limitations:**
|
68 |
|
|
|
78 |
|
79 |
**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
[arxiv]: https://arxiv.org/abs/2306.05284
|
Makefile
CHANGED
@@ -1,15 +1,3 @@
|
|
1 |
-
INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
|
2 |
-
dataset.train.num_samples=10 dataset.valid.num_samples=10 \
|
3 |
-
dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
|
4 |
-
logging.level=DEBUG
|
5 |
-
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e
|
6 |
-
INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
|
7 |
-
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
|
8 |
-
INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
|
9 |
-
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
|
10 |
-
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
|
11 |
-
checkpoint.save_last=false # Using compression model from 616d7b3c
|
12 |
-
|
13 |
default: linter tests
|
14 |
|
15 |
install:
|
@@ -22,19 +10,12 @@ linter:
|
|
22 |
|
23 |
tests:
|
24 |
coverage run -m pytest tests
|
25 |
-
coverage report
|
26 |
-
|
27 |
-
tests_integ:
|
28 |
-
$(INTEG_COMPRESSION)
|
29 |
-
$(INTEG_MBD)
|
30 |
-
$(INTEG_MUSICGEN)
|
31 |
-
$(INTEG_AUDIOGEN)
|
32 |
-
|
33 |
|
34 |
-
|
35 |
-
pdoc3 --html -o
|
36 |
|
37 |
dist:
|
38 |
python setup.py sdist
|
39 |
|
40 |
-
.PHONY: linter tests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
default: linter tests
|
2 |
|
3 |
install:
|
|
|
10 |
|
11 |
tests:
|
12 |
coverage run -m pytest tests
|
13 |
+
coverage report --include 'audiocraft/*'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
+
docs:
|
16 |
+
pdoc3 --html -o docs -f audiocraft
|
17 |
|
18 |
dist:
|
19 |
python setup.py sdist
|
20 |
|
21 |
+
.PHONY: linter tests docs dist
|
README.md
CHANGED
@@ -5,27 +5,40 @@ tags:
|
|
5 |
- "music generation"
|
6 |
- "language models"
|
7 |
- "LLMs"
|
8 |
-
app_file: "
|
9 |
emoji: 🎵
|
10 |
-
colorFrom:
|
11 |
colorTo: blue
|
12 |
sdk: gradio
|
13 |
sdk_version: 3.34.0
|
14 |
pinned: true
|
15 |
license: "cc-by-nc-4.0"
|
16 |
-
disable_embedding: true
|
17 |
---
|
18 |
-
#
|
19 |
![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
|
20 |
![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
|
21 |
![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
|
22 |
|
23 |
-
|
24 |
-
for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
## Installation
|
28 |
-
|
29 |
|
30 |
```shell
|
31 |
# Best to make sure you have torch installed first, in particular before installing xformers.
|
@@ -34,68 +47,71 @@ pip install 'torch>=2.0'
|
|
34 |
# Then proceed to one of the following
|
35 |
pip install -U audiocraft # stable release
|
36 |
pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
|
37 |
-
pip install -e . # or if you cloned the repo locally
|
38 |
-
```
|
39 |
-
|
40 |
-
We also recommend having `ffmpeg` installed, either through your system or Anaconda:
|
41 |
-
```bash
|
42 |
-
sudo apt-get install ffmpeg
|
43 |
-
# Or if you are using Anaconda or Miniconda
|
44 |
-
conda install "ffmpeg<5" -c conda-forge
|
45 |
```
|
46 |
|
47 |
-
##
|
|
|
48 |
|
49 |
-
|
50 |
-
* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
|
51 |
-
* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
|
52 |
-
* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
|
53 |
-
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
|
54 |
|
55 |
-
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
the
|
60 |
|
61 |
-
|
62 |
-
that provides pointers to configuration, example grids and model/task-specific information and FAQ.
|
63 |
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
|
|
|
|
|
68 |
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
## FAQ
|
71 |
-
|
72 |
-
#### Is the training code available?
|
73 |
|
74 |
-
|
75 |
|
76 |
-
|
77 |
|
78 |
-
|
79 |
-
In order to change the cache location of the other Hugging Face models, please check out the [Hugging Face Transformers documentation for the cache setup](https://huggingface.co/docs/transformers/installation#cache-setup).
|
80 |
-
Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and want to change the download location for Demucs, refer to the [Torch Hub documentation](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved).
|
81 |
|
|
|
82 |
|
83 |
-
|
84 |
-
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
|
85 |
-
* The models weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
|
86 |
|
87 |
|
88 |
## Citation
|
89 |
-
|
90 |
-
For the general framework of AudioCraft, please cite the following.
|
91 |
```
|
92 |
@article{copet2023simple,
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
}
|
98 |
```
|
99 |
|
100 |
-
|
101 |
-
[
|
|
|
|
|
|
|
|
|
|
5 |
- "music generation"
|
6 |
- "language models"
|
7 |
- "LLMs"
|
8 |
+
app_file: "app_batched.py"
|
9 |
emoji: 🎵
|
10 |
+
colorFrom: white
|
11 |
colorTo: blue
|
12 |
sdk: gradio
|
13 |
sdk_version: 3.34.0
|
14 |
pinned: true
|
15 |
license: "cc-by-nc-4.0"
|
|
|
16 |
---
|
17 |
+
# Audiocraft
|
18 |
![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
|
19 |
![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
|
20 |
![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
|
21 |
|
22 |
+
Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model.
|
|
|
23 |
|
24 |
+
## MusicGen
|
25 |
+
|
26 |
+
Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
|
27 |
+
Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't not require a self-supervised semantic representation, and it generates
|
28 |
+
all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
|
29 |
+
them in parallel, thus having only 50 auto-regressive steps per second of audio.
|
30 |
+
Check out our [sample page][musicgen_samples] or test the available demo!
|
31 |
+
|
32 |
+
<a target="_blank" href="https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing">
|
33 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
34 |
+
</a>
|
35 |
+
<a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
|
36 |
+
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HugginFace"/>
|
37 |
+
</a>
|
38 |
+
<br>
|
39 |
|
40 |
## Installation
|
41 |
+
Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
|
42 |
|
43 |
```shell
|
44 |
# Best to make sure you have torch installed first, in particular before installing xformers.
|
|
|
47 |
# Then proceed to one of the following
|
48 |
pip install -U audiocraft # stable release
|
49 |
pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
|
50 |
+
pip install -e . # or if you cloned the repo locally
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
```
|
52 |
|
53 |
+
## Usage
|
54 |
+
You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing). Finally, a demo is also available on the [`facebook/MusiGen` HugginFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
|
55 |
|
56 |
+
## API
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
We provide a simple API and 4 pre-trained models. The pre trained models are:
|
59 |
+
- `small`: 300M model, text to music only,
|
60 |
+
- `medium`: 1.5B model, text to music only,
|
61 |
+
- `melody`: 1.5B model, text to music and text+melody to music,
|
62 |
+
- `large`: 3.3B model, text to music only.
|
63 |
|
64 |
+
We observe the best trade-off between quality and compute with the `medium` or `melody` model.
|
65 |
+
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
|
66 |
+
GPUs will be able to generate short sequences, or longer sequences with the `small` model.
|
67 |
|
68 |
+
See after a quick example for using the API.
|
|
|
69 |
|
70 |
+
```python
|
71 |
+
import torchaudio
|
72 |
+
from audiocraft.models import MusicGen
|
73 |
+
from audiocraft.data.audio import audio_write
|
74 |
|
75 |
+
model = MusicGen.get_pretrained('melody')
|
76 |
+
model.set_generation_params(duration=8) # generate 8 seconds.
|
77 |
+
wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
|
78 |
+
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
|
79 |
+
wav = model.generate(descriptions) # generates 3 samples.
|
80 |
|
81 |
+
melody, sr = torchaudio.load('./assets/bach.mp3')
|
82 |
+
# generates using the melody from the given audio and the provided descriptions.
|
83 |
+
wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
|
84 |
|
85 |
+
for idx, one_wav in enumerate(wav):
|
86 |
+
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
87 |
+
audio_write(f'{idx}', one_wav, model.sample_rate, strategy="loudness")
|
88 |
+
```
|
89 |
|
|
|
|
|
|
|
90 |
|
91 |
+
## Model Card
|
92 |
|
93 |
+
See [the model card page](./MODEL_CARD.md).
|
94 |
|
95 |
+
## FAQ
|
|
|
|
|
96 |
|
97 |
+
#### Will the training code be released?
|
98 |
|
99 |
+
Yes. We will soon release the training code for MusicGen and EnCodec.
|
|
|
|
|
100 |
|
101 |
|
102 |
## Citation
|
|
|
|
|
103 |
```
|
104 |
@article{copet2023simple,
|
105 |
+
title={Simple and Controllable Music Generation},
|
106 |
+
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
|
107 |
+
year={2023},
|
108 |
+
journal={arXiv preprint arXiv:2306.05284},
|
109 |
}
|
110 |
```
|
111 |
|
112 |
+
## License
|
113 |
+
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
|
114 |
+
* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
|
115 |
+
|
116 |
+
[arxiv]: https://arxiv.org/abs/2306.05284
|
117 |
+
[musicgen_samples]: https://ai.honu.io/papers/musicgen/
|
app.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
|
5 |
+
This source code is licensed under the license found in the
|
6 |
+
LICENSE file in the root directory of this source tree.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import gradio as gr
|
11 |
+
from hf_loading import get_pretrained
|
12 |
+
|
13 |
+
|
14 |
+
MODEL = None
|
15 |
+
|
16 |
+
|
17 |
+
def load_model(version):
|
18 |
+
print("Loading model", version)
|
19 |
+
return get_pretrained(version)
|
20 |
+
|
21 |
+
|
22 |
+
def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
23 |
+
global MODEL
|
24 |
+
topk = int(topk)
|
25 |
+
if MODEL is None or MODEL.name != model:
|
26 |
+
MODEL = load_model(model)
|
27 |
+
|
28 |
+
if duration > MODEL.lm.cfg.dataset.segment_duration:
|
29 |
+
raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
|
30 |
+
MODEL.set_generation_params(
|
31 |
+
use_sampling=True,
|
32 |
+
top_k=topk,
|
33 |
+
top_p=topp,
|
34 |
+
temperature=temperature,
|
35 |
+
cfg_coef=cfg_coef,
|
36 |
+
duration=duration,
|
37 |
+
)
|
38 |
+
|
39 |
+
if melody:
|
40 |
+
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
|
41 |
+
print(melody.shape)
|
42 |
+
if melody.dim() == 2:
|
43 |
+
melody = melody[None]
|
44 |
+
melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
|
45 |
+
output = MODEL.generate_with_chroma(
|
46 |
+
descriptions=[text],
|
47 |
+
melody_wavs=melody,
|
48 |
+
melody_sample_rate=sr,
|
49 |
+
progress=False
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
output = MODEL.generate(descriptions=[text], progress=False)
|
53 |
+
|
54 |
+
output = output.detach().cpu().numpy()
|
55 |
+
return MODEL.sample_rate, output
|
56 |
+
|
57 |
+
|
58 |
+
with gr.Blocks() as demo:
|
59 |
+
gr.Markdown(
|
60 |
+
"""
|
61 |
+
# MusicGen
|
62 |
+
|
63 |
+
This is the demo for MusicGen, a simple and controllable model for music generation presented at: "Simple and Controllable Music Generation".
|
64 |
+
|
65 |
+
Below we present 3 model variations:
|
66 |
+
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
67 |
+
2. Small -- a 300M transformer decoder conditioned on text only.
|
68 |
+
3. Medium -- a 1.5B transformer decoder conditioned on text only.
|
69 |
+
4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
|
70 |
+
|
71 |
+
When the optional melody conditioning wav is provided, the model will extract
|
72 |
+
a broad melody and try to follow it in the generated samples.
|
73 |
+
|
74 |
+
For skipping queue, you can duplicate this space, and upgrade to GPU in the settings.
|
75 |
+
<br/>
|
76 |
+
<a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true">
|
77 |
+
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
78 |
+
</p>
|
79 |
+
|
80 |
+
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
81 |
+
for more details.
|
82 |
+
"""
|
83 |
+
)
|
84 |
+
with gr.Row():
|
85 |
+
with gr.Column():
|
86 |
+
with gr.Row():
|
87 |
+
text = gr.Text(label="Input Text", interactive=True)
|
88 |
+
melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
|
89 |
+
with gr.Row():
|
90 |
+
submit = gr.Button("Submit")
|
91 |
+
with gr.Row():
|
92 |
+
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
93 |
+
with gr.Row():
|
94 |
+
duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
|
95 |
+
with gr.Row():
|
96 |
+
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
97 |
+
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
98 |
+
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
99 |
+
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
100 |
+
with gr.Column():
|
101 |
+
output = gr.Audio(label="Generated Music", type="numpy")
|
102 |
+
submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
|
103 |
+
gr.Examples(
|
104 |
+
fn=predict,
|
105 |
+
examples=[
|
106 |
+
[
|
107 |
+
"An 80s driving pop song with heavy drums and synth pads in the background",
|
108 |
+
"./assets/bach.mp3",
|
109 |
+
"melody"
|
110 |
+
],
|
111 |
+
[
|
112 |
+
"A cheerful country song with acoustic guitars",
|
113 |
+
"./assets/bolero_ravel.mp3",
|
114 |
+
"melody"
|
115 |
+
],
|
116 |
+
[
|
117 |
+
"90s rock song with electric guitar and heavy drums",
|
118 |
+
None,
|
119 |
+
"medium"
|
120 |
+
],
|
121 |
+
[
|
122 |
+
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
|
123 |
+
"./assets/bach.mp3",
|
124 |
+
"melody"
|
125 |
+
],
|
126 |
+
[
|
127 |
+
"lofi slow bpm electro chill with organic samples",
|
128 |
+
None,
|
129 |
+
"medium",
|
130 |
+
],
|
131 |
+
],
|
132 |
+
inputs=[text, melody, model],
|
133 |
+
outputs=[output]
|
134 |
+
)
|
135 |
+
|
136 |
+
demo.launch()
|
app_batched.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
|
5 |
+
This source code is licensed under the license found in the
|
6 |
+
LICENSE file in the root directory of this source tree.
|
7 |
+
"""
|
8 |
+
|
9 |
+
from tempfile import NamedTemporaryFile
|
10 |
+
import torch
|
11 |
+
import gradio as gr
|
12 |
+
from audiocraft.data.audio_utils import convert_audio
|
13 |
+
from audiocraft.data.audio import audio_write
|
14 |
+
from hf_loading import get_pretrained
|
15 |
+
|
16 |
+
|
17 |
+
MODEL = None
|
18 |
+
|
19 |
+
|
20 |
+
def load_model():
|
21 |
+
print("Loading model")
|
22 |
+
return get_pretrained("melody")
|
23 |
+
|
24 |
+
|
25 |
+
def predict(texts, melodies):
|
26 |
+
global MODEL
|
27 |
+
if MODEL is None:
|
28 |
+
MODEL = load_model()
|
29 |
+
|
30 |
+
duration = 12
|
31 |
+
MODEL.set_generation_params(duration=duration)
|
32 |
+
|
33 |
+
print(texts, melodies)
|
34 |
+
processed_melodies = []
|
35 |
+
|
36 |
+
target_sr = 32000
|
37 |
+
target_ac = 1
|
38 |
+
for melody in melodies:
|
39 |
+
if melody is None:
|
40 |
+
processed_melodies.append(None)
|
41 |
+
else:
|
42 |
+
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
|
43 |
+
if melody.dim() == 1:
|
44 |
+
melody = melody[None]
|
45 |
+
melody = melody[..., :int(sr * duration)]
|
46 |
+
melody = convert_audio(melody, sr, target_sr, target_ac)
|
47 |
+
processed_melodies.append(melody)
|
48 |
+
|
49 |
+
outputs = MODEL.generate_with_chroma(
|
50 |
+
descriptions=texts,
|
51 |
+
melody_wavs=processed_melodies,
|
52 |
+
melody_sample_rate=target_sr,
|
53 |
+
progress=False
|
54 |
+
)
|
55 |
+
|
56 |
+
outputs = outputs.detach().cpu().float()
|
57 |
+
out_files = []
|
58 |
+
for output in outputs:
|
59 |
+
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
60 |
+
audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
|
61 |
+
waveform_video = gr.make_waveform(file.name)
|
62 |
+
out_files.append(waveform_video)
|
63 |
+
print(out_files)
|
64 |
+
return [out_files]
|
65 |
+
|
66 |
+
|
67 |
+
with gr.Blocks() as demo:
|
68 |
+
gr.Markdown(
|
69 |
+
"""
|
70 |
+
# MusicGen
|
71 |
+
|
72 |
+
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
73 |
+
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
|
74 |
+
<br/>
|
75 |
+
<a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
|
76 |
+
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
77 |
+
for longer sequences, more control and no queue</p>
|
78 |
+
"""
|
79 |
+
)
|
80 |
+
with gr.Row():
|
81 |
+
with gr.Column():
|
82 |
+
with gr.Row():
|
83 |
+
text = gr.Text(label="Describe your music", lines=2, interactive=True)
|
84 |
+
melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
|
85 |
+
with gr.Row():
|
86 |
+
submit = gr.Button("Generate")
|
87 |
+
with gr.Column():
|
88 |
+
output = gr.Video(label="Generated Music")
|
89 |
+
submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
|
90 |
+
gr.Examples(
|
91 |
+
fn=predict,
|
92 |
+
examples=[
|
93 |
+
[
|
94 |
+
"An 80s driving pop song with heavy drums and synth pads in the background",
|
95 |
+
"./assets/bach.mp3",
|
96 |
+
],
|
97 |
+
[
|
98 |
+
"A cheerful country song with acoustic guitars",
|
99 |
+
"./assets/bolero_ravel.mp3",
|
100 |
+
],
|
101 |
+
[
|
102 |
+
"90s rock song with electric guitar and heavy drums",
|
103 |
+
None,
|
104 |
+
],
|
105 |
+
[
|
106 |
+
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
|
107 |
+
"./assets/bach.mp3",
|
108 |
+
],
|
109 |
+
[
|
110 |
+
"lofi slow bpm electro chill with organic samples",
|
111 |
+
None,
|
112 |
+
],
|
113 |
+
],
|
114 |
+
inputs=[text, melody],
|
115 |
+
outputs=[output]
|
116 |
+
)
|
117 |
+
gr.Markdown("""
|
118 |
+
### More details
|
119 |
+
By typing a description of the music you want and an optional audio used for melody conditioning,
|
120 |
+
the model will extract the broad melody from the uploaded wav if provided and generate a 12s extract with the `melody` model.
|
121 |
+
|
122 |
+
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
123 |
+
|
124 |
+
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
125 |
+
for more details.
|
126 |
+
""")
|
127 |
+
|
128 |
+
demo.queue(max_size=15).launch()
|
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
DELETED
Binary file (15.2 kB)
|
|
assets/sirens_and_a_humming_engine_approach_and_pass.mp3
DELETED
Binary file (15.2 kB)
|
|
audiocraft/__init__.py
CHANGED
@@ -3,24 +3,8 @@
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""
|
7 |
-
AudioCraft is a general framework for training audio generative models.
|
8 |
-
At the moment we provide the training code for:
|
9 |
-
|
10 |
-
- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
|
11 |
-
text-to-music and melody+text autoregressive generative model.
|
12 |
-
For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
|
13 |
-
`audiocraft.models.musicgen.MusicGen`.
|
14 |
-
- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
|
15 |
-
text-to-general-audio generative model.
|
16 |
-
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
|
17 |
-
neural audio codec which provides an excellent tokenizer for autoregressive language models.
|
18 |
-
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
|
19 |
-
- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
|
20 |
-
improves the perceived quality and reduces the artifacts coming from adversarial decoders.
|
21 |
-
"""
|
22 |
|
23 |
# flake8: noqa
|
24 |
from . import data, modules, models
|
25 |
|
26 |
-
__version__ = '
|
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# flake8: noqa
|
8 |
from . import data, modules, models
|
9 |
|
10 |
+
__version__ = '0.0.1'
|
audiocraft/adversarial/__init__.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Adversarial losses and discriminator architectures."""
|
7 |
-
|
8 |
-
# flake8: noqa
|
9 |
-
from .discriminators import (
|
10 |
-
MultiPeriodDiscriminator,
|
11 |
-
MultiScaleDiscriminator,
|
12 |
-
MultiScaleSTFTDiscriminator
|
13 |
-
)
|
14 |
-
from .losses import (
|
15 |
-
AdversarialLoss,
|
16 |
-
AdvLossType,
|
17 |
-
get_adv_criterion,
|
18 |
-
get_fake_criterion,
|
19 |
-
get_real_criterion,
|
20 |
-
FeatLossType,
|
21 |
-
FeatureMatchingLoss
|
22 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/adversarial/discriminators/__init__.py
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
# flake8: noqa
|
8 |
-
from .mpd import MultiPeriodDiscriminator
|
9 |
-
from .msd import MultiScaleDiscriminator
|
10 |
-
from .msstftd import MultiScaleSTFTDiscriminator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/adversarial/discriminators/base.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from abc import ABC, abstractmethod
|
8 |
-
import typing as tp
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
|
13 |
-
|
14 |
-
FeatureMapType = tp.List[torch.Tensor]
|
15 |
-
LogitsType = torch.Tensor
|
16 |
-
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
|
17 |
-
|
18 |
-
|
19 |
-
class MultiDiscriminator(ABC, nn.Module):
|
20 |
-
"""Base implementation for discriminators composed of sub-discriminators acting at different scales.
|
21 |
-
"""
|
22 |
-
def __init__(self):
|
23 |
-
super().__init__()
|
24 |
-
|
25 |
-
@abstractmethod
|
26 |
-
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
27 |
-
...
|
28 |
-
|
29 |
-
@property
|
30 |
-
@abstractmethod
|
31 |
-
def num_discriminators(self) -> int:
|
32 |
-
"""Number of discriminators.
|
33 |
-
"""
|
34 |
-
...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/adversarial/discriminators/mpd.py
DELETED
@@ -1,106 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import typing as tp
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
|
13 |
-
from ...modules import NormConv2d
|
14 |
-
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
15 |
-
|
16 |
-
|
17 |
-
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
18 |
-
return int((kernel_size * dilation - dilation) / 2)
|
19 |
-
|
20 |
-
|
21 |
-
class PeriodDiscriminator(nn.Module):
|
22 |
-
"""Period sub-discriminator.
|
23 |
-
|
24 |
-
Args:
|
25 |
-
period (int): Period between samples of audio.
|
26 |
-
in_channels (int): Number of input channels.
|
27 |
-
out_channels (int): Number of output channels.
|
28 |
-
n_layers (int): Number of convolutional layers.
|
29 |
-
kernel_sizes (list of int): Kernel sizes for convolutions.
|
30 |
-
stride (int): Stride for convolutions.
|
31 |
-
filters (int): Initial number of filters in convolutions.
|
32 |
-
filters_scale (int): Multiplier of number of filters as we increase depth.
|
33 |
-
max_filters (int): Maximum number of filters.
|
34 |
-
norm (str): Normalization method.
|
35 |
-
activation (str): Activation function.
|
36 |
-
activation_params (dict): Parameters to provide to the activation function.
|
37 |
-
"""
|
38 |
-
def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
|
39 |
-
n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
|
40 |
-
filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
|
41 |
-
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
|
42 |
-
activation_params: dict = {'negative_slope': 0.2}):
|
43 |
-
super().__init__()
|
44 |
-
self.period = period
|
45 |
-
self.n_layers = n_layers
|
46 |
-
self.activation = getattr(torch.nn, activation)(**activation_params)
|
47 |
-
self.convs = nn.ModuleList()
|
48 |
-
in_chs = in_channels
|
49 |
-
for i in range(self.n_layers):
|
50 |
-
out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
|
51 |
-
eff_stride = 1 if i == self.n_layers - 1 else stride
|
52 |
-
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
|
53 |
-
padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
|
54 |
-
in_chs = out_chs
|
55 |
-
self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
|
56 |
-
padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
|
57 |
-
|
58 |
-
def forward(self, x: torch.Tensor):
|
59 |
-
fmap = []
|
60 |
-
# 1d to 2d
|
61 |
-
b, c, t = x.shape
|
62 |
-
if t % self.period != 0: # pad first
|
63 |
-
n_pad = self.period - (t % self.period)
|
64 |
-
x = F.pad(x, (0, n_pad), 'reflect')
|
65 |
-
t = t + n_pad
|
66 |
-
x = x.view(b, c, t // self.period, self.period)
|
67 |
-
|
68 |
-
for conv in self.convs:
|
69 |
-
x = conv(x)
|
70 |
-
x = self.activation(x)
|
71 |
-
fmap.append(x)
|
72 |
-
x = self.conv_post(x)
|
73 |
-
fmap.append(x)
|
74 |
-
# x = torch.flatten(x, 1, -1)
|
75 |
-
|
76 |
-
return x, fmap
|
77 |
-
|
78 |
-
|
79 |
-
class MultiPeriodDiscriminator(MultiDiscriminator):
|
80 |
-
"""Multi-Period (MPD) Discriminator.
|
81 |
-
|
82 |
-
Args:
|
83 |
-
in_channels (int): Number of input channels.
|
84 |
-
out_channels (int): Number of output channels.
|
85 |
-
periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
|
86 |
-
**kwargs: Additional args for `PeriodDiscriminator`
|
87 |
-
"""
|
88 |
-
def __init__(self, in_channels: int = 1, out_channels: int = 1,
|
89 |
-
periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
|
90 |
-
super().__init__()
|
91 |
-
self.discriminators = nn.ModuleList([
|
92 |
-
PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
|
93 |
-
])
|
94 |
-
|
95 |
-
@property
|
96 |
-
def num_discriminators(self):
|
97 |
-
return len(self.discriminators)
|
98 |
-
|
99 |
-
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
100 |
-
logits = []
|
101 |
-
fmaps = []
|
102 |
-
for disc in self.discriminators:
|
103 |
-
logit, fmap = disc(x)
|
104 |
-
logits.append(logit)
|
105 |
-
fmaps.append(fmap)
|
106 |
-
return logits, fmaps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/adversarial/discriminators/msd.py
DELETED
@@ -1,126 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import typing as tp
|
8 |
-
|
9 |
-
import numpy as np
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
|
13 |
-
from ...modules import NormConv1d
|
14 |
-
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
15 |
-
|
16 |
-
|
17 |
-
class ScaleDiscriminator(nn.Module):
|
18 |
-
"""Waveform sub-discriminator.
|
19 |
-
|
20 |
-
Args:
|
21 |
-
in_channels (int): Number of input channels.
|
22 |
-
out_channels (int): Number of output channels.
|
23 |
-
kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
|
24 |
-
filters (int): Number of initial filters for convolutions.
|
25 |
-
max_filters (int): Maximum number of filters.
|
26 |
-
downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
|
27 |
-
inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
|
28 |
-
groups (Sequence[int] or None): Groups for inner convolutions.
|
29 |
-
strides (Sequence[int] or None): Strides for inner convolutions.
|
30 |
-
paddings (Sequence[int] or None): Paddings for inner convolutions.
|
31 |
-
norm (str): Normalization method.
|
32 |
-
activation (str): Activation function.
|
33 |
-
activation_params (dict): Parameters to provide to the activation function.
|
34 |
-
pad (str): Padding for initial convolution.
|
35 |
-
pad_params (dict): Parameters to provide to the padding module.
|
36 |
-
"""
|
37 |
-
def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
|
38 |
-
filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
|
39 |
-
inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
|
40 |
-
strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
|
41 |
-
norm: str = 'weight_norm', activation: str = 'LeakyReLU',
|
42 |
-
activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
|
43 |
-
pad_params: dict = {}):
|
44 |
-
super().__init__()
|
45 |
-
assert len(kernel_sizes) == 2
|
46 |
-
assert kernel_sizes[0] % 2 == 1
|
47 |
-
assert kernel_sizes[1] % 2 == 1
|
48 |
-
assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
|
49 |
-
assert (groups is None or len(groups) == len(downsample_scales))
|
50 |
-
assert (strides is None or len(strides) == len(downsample_scales))
|
51 |
-
assert (paddings is None or len(paddings) == len(downsample_scales))
|
52 |
-
self.activation = getattr(torch.nn, activation)(**activation_params)
|
53 |
-
self.convs = nn.ModuleList()
|
54 |
-
self.convs.append(
|
55 |
-
nn.Sequential(
|
56 |
-
getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
|
57 |
-
NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
|
58 |
-
)
|
59 |
-
)
|
60 |
-
|
61 |
-
in_chs = filters
|
62 |
-
for i, downsample_scale in enumerate(downsample_scales):
|
63 |
-
out_chs = min(in_chs * downsample_scale, max_filters)
|
64 |
-
default_kernel_size = downsample_scale * 10 + 1
|
65 |
-
default_stride = downsample_scale
|
66 |
-
default_padding = (default_kernel_size - 1) // 2
|
67 |
-
default_groups = in_chs // 4
|
68 |
-
self.convs.append(
|
69 |
-
NormConv1d(in_chs, out_chs,
|
70 |
-
kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
|
71 |
-
stride=strides[i] if strides else default_stride,
|
72 |
-
groups=groups[i] if groups else default_groups,
|
73 |
-
padding=paddings[i] if paddings else default_padding,
|
74 |
-
norm=norm))
|
75 |
-
in_chs = out_chs
|
76 |
-
|
77 |
-
out_chs = min(in_chs * 2, max_filters)
|
78 |
-
self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
|
79 |
-
padding=(kernel_sizes[0] - 1) // 2, norm=norm))
|
80 |
-
self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
|
81 |
-
padding=(kernel_sizes[1] - 1) // 2, norm=norm)
|
82 |
-
|
83 |
-
def forward(self, x: torch.Tensor):
|
84 |
-
fmap = []
|
85 |
-
for layer in self.convs:
|
86 |
-
x = layer(x)
|
87 |
-
x = self.activation(x)
|
88 |
-
fmap.append(x)
|
89 |
-
x = self.conv_post(x)
|
90 |
-
fmap.append(x)
|
91 |
-
# x = torch.flatten(x, 1, -1)
|
92 |
-
return x, fmap
|
93 |
-
|
94 |
-
|
95 |
-
class MultiScaleDiscriminator(MultiDiscriminator):
|
96 |
-
"""Multi-Scale (MSD) Discriminator,
|
97 |
-
|
98 |
-
Args:
|
99 |
-
in_channels (int): Number of input channels.
|
100 |
-
out_channels (int): Number of output channels.
|
101 |
-
downsample_factor (int): Downsampling factor between the different scales.
|
102 |
-
scale_norms (Sequence[str]): Normalization for each sub-discriminator.
|
103 |
-
**kwargs: Additional args for ScaleDiscriminator.
|
104 |
-
"""
|
105 |
-
def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
|
106 |
-
scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
|
107 |
-
super().__init__()
|
108 |
-
self.discriminators = nn.ModuleList([
|
109 |
-
ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
|
110 |
-
])
|
111 |
-
self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
|
112 |
-
|
113 |
-
@property
|
114 |
-
def num_discriminators(self):
|
115 |
-
return len(self.discriminators)
|
116 |
-
|
117 |
-
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
118 |
-
logits = []
|
119 |
-
fmaps = []
|
120 |
-
for i, disc in enumerate(self.discriminators):
|
121 |
-
if i != 0:
|
122 |
-
self.downsample(x)
|
123 |
-
logit, fmap = disc(x)
|
124 |
-
logits.append(logit)
|
125 |
-
fmaps.append(fmap)
|
126 |
-
return logits, fmaps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/adversarial/discriminators/msstftd.py
DELETED
@@ -1,134 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import typing as tp
|
8 |
-
|
9 |
-
import torchaudio
|
10 |
-
import torch
|
11 |
-
from torch import nn
|
12 |
-
from einops import rearrange
|
13 |
-
|
14 |
-
from ...modules import NormConv2d
|
15 |
-
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
|
16 |
-
|
17 |
-
|
18 |
-
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
|
19 |
-
return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
|
20 |
-
|
21 |
-
|
22 |
-
class DiscriminatorSTFT(nn.Module):
|
23 |
-
"""STFT sub-discriminator.
|
24 |
-
|
25 |
-
Args:
|
26 |
-
filters (int): Number of filters in convolutions.
|
27 |
-
in_channels (int): Number of input channels.
|
28 |
-
out_channels (int): Number of output channels.
|
29 |
-
n_fft (int): Size of FFT for each scale.
|
30 |
-
hop_length (int): Length of hop between STFT windows for each scale.
|
31 |
-
kernel_size (tuple of int): Inner Conv2d kernel sizes.
|
32 |
-
stride (tuple of int): Inner Conv2d strides.
|
33 |
-
dilations (list of int): Inner Conv2d dilation on the time dimension.
|
34 |
-
win_length (int): Window size for each scale.
|
35 |
-
normalized (bool): Whether to normalize by magnitude after stft.
|
36 |
-
norm (str): Normalization method.
|
37 |
-
activation (str): Activation function.
|
38 |
-
activation_params (dict): Parameters to provide to the activation function.
|
39 |
-
growth (int): Growth factor for the filters.
|
40 |
-
"""
|
41 |
-
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
|
42 |
-
n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
|
43 |
-
filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
|
44 |
-
stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
|
45 |
-
activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
|
46 |
-
super().__init__()
|
47 |
-
assert len(kernel_size) == 2
|
48 |
-
assert len(stride) == 2
|
49 |
-
self.filters = filters
|
50 |
-
self.in_channels = in_channels
|
51 |
-
self.out_channels = out_channels
|
52 |
-
self.n_fft = n_fft
|
53 |
-
self.hop_length = hop_length
|
54 |
-
self.win_length = win_length
|
55 |
-
self.normalized = normalized
|
56 |
-
self.activation = getattr(torch.nn, activation)(**activation_params)
|
57 |
-
self.spec_transform = torchaudio.transforms.Spectrogram(
|
58 |
-
n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
|
59 |
-
normalized=self.normalized, center=False, pad_mode=None, power=None)
|
60 |
-
spec_channels = 2 * self.in_channels
|
61 |
-
self.convs = nn.ModuleList()
|
62 |
-
self.convs.append(
|
63 |
-
NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
|
64 |
-
)
|
65 |
-
in_chs = min(filters_scale * self.filters, max_filters)
|
66 |
-
for i, dilation in enumerate(dilations):
|
67 |
-
out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
|
68 |
-
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
|
69 |
-
dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
|
70 |
-
norm=norm))
|
71 |
-
in_chs = out_chs
|
72 |
-
out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
|
73 |
-
self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
|
74 |
-
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
75 |
-
norm=norm))
|
76 |
-
self.conv_post = NormConv2d(out_chs, self.out_channels,
|
77 |
-
kernel_size=(kernel_size[0], kernel_size[0]),
|
78 |
-
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
79 |
-
norm=norm)
|
80 |
-
|
81 |
-
def forward(self, x: torch.Tensor):
|
82 |
-
fmap = []
|
83 |
-
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
|
84 |
-
z = torch.cat([z.real, z.imag], dim=1)
|
85 |
-
z = rearrange(z, 'b c w t -> b c t w')
|
86 |
-
for i, layer in enumerate(self.convs):
|
87 |
-
z = layer(z)
|
88 |
-
z = self.activation(z)
|
89 |
-
fmap.append(z)
|
90 |
-
z = self.conv_post(z)
|
91 |
-
return z, fmap
|
92 |
-
|
93 |
-
|
94 |
-
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
|
95 |
-
"""Multi-Scale STFT (MS-STFT) discriminator.
|
96 |
-
|
97 |
-
Args:
|
98 |
-
filters (int): Number of filters in convolutions.
|
99 |
-
in_channels (int): Number of input channels.
|
100 |
-
out_channels (int): Number of output channels.
|
101 |
-
sep_channels (bool): Separate channels to distinct samples for stereo support.
|
102 |
-
n_ffts (Sequence[int]): Size of FFT for each scale.
|
103 |
-
hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
|
104 |
-
win_lengths (Sequence[int]): Window size for each scale.
|
105 |
-
**kwargs: Additional args for STFTDiscriminator.
|
106 |
-
"""
|
107 |
-
def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
|
108 |
-
n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
|
109 |
-
win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
|
110 |
-
super().__init__()
|
111 |
-
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
112 |
-
self.sep_channels = sep_channels
|
113 |
-
self.discriminators = nn.ModuleList([
|
114 |
-
DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
|
115 |
-
n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
|
116 |
-
for i in range(len(n_ffts))
|
117 |
-
])
|
118 |
-
|
119 |
-
@property
|
120 |
-
def num_discriminators(self):
|
121 |
-
return len(self.discriminators)
|
122 |
-
|
123 |
-
def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
|
124 |
-
B, C, T = x.shape
|
125 |
-
return x.view(-1, 1, T)
|
126 |
-
|
127 |
-
def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
|
128 |
-
logits = []
|
129 |
-
fmaps = []
|
130 |
-
for disc in self.discriminators:
|
131 |
-
logit, fmap = disc(x)
|
132 |
-
logits.append(logit)
|
133 |
-
fmaps.append(fmap)
|
134 |
-
return logits, fmaps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/adversarial/losses.py
DELETED
@@ -1,228 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Utility module to handle adversarial losses without requiring to mess up the main training loop.
|
9 |
-
"""
|
10 |
-
|
11 |
-
import typing as tp
|
12 |
-
|
13 |
-
import flashy
|
14 |
-
import torch
|
15 |
-
import torch.nn as nn
|
16 |
-
import torch.nn.functional as F
|
17 |
-
|
18 |
-
|
19 |
-
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
|
20 |
-
|
21 |
-
|
22 |
-
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
|
23 |
-
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
|
24 |
-
|
25 |
-
|
26 |
-
class AdversarialLoss(nn.Module):
|
27 |
-
"""Adversary training wrapper.
|
28 |
-
|
29 |
-
Args:
|
30 |
-
adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
|
31 |
-
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
|
32 |
-
where the first item is a list of logits and the second item is a list of feature maps.
|
33 |
-
optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
|
34 |
-
loss (AdvLossType): Loss function for generator training.
|
35 |
-
loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
|
36 |
-
loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
|
37 |
-
loss_feat (FeatLossType): Feature matching loss function for generator training.
|
38 |
-
normalize (bool): Whether to normalize by number of sub-discriminators.
|
39 |
-
|
40 |
-
Example of usage:
|
41 |
-
adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
|
42 |
-
for real in loader:
|
43 |
-
noise = torch.randn(...)
|
44 |
-
fake = model(noise)
|
45 |
-
adv_loss.train_adv(fake, real)
|
46 |
-
loss, _ = adv_loss(fake, real)
|
47 |
-
loss.backward()
|
48 |
-
"""
|
49 |
-
def __init__(self,
|
50 |
-
adversary: nn.Module,
|
51 |
-
optimizer: torch.optim.Optimizer,
|
52 |
-
loss: AdvLossType,
|
53 |
-
loss_real: AdvLossType,
|
54 |
-
loss_fake: AdvLossType,
|
55 |
-
loss_feat: tp.Optional[FeatLossType] = None,
|
56 |
-
normalize: bool = True):
|
57 |
-
super().__init__()
|
58 |
-
self.adversary: nn.Module = adversary
|
59 |
-
flashy.distrib.broadcast_model(self.adversary)
|
60 |
-
self.optimizer = optimizer
|
61 |
-
self.loss = loss
|
62 |
-
self.loss_real = loss_real
|
63 |
-
self.loss_fake = loss_fake
|
64 |
-
self.loss_feat = loss_feat
|
65 |
-
self.normalize = normalize
|
66 |
-
|
67 |
-
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
68 |
-
# Add the optimizer state dict inside our own.
|
69 |
-
super()._save_to_state_dict(destination, prefix, keep_vars)
|
70 |
-
destination[prefix + 'optimizer'] = self.optimizer.state_dict()
|
71 |
-
return destination
|
72 |
-
|
73 |
-
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
74 |
-
# Load optimizer state.
|
75 |
-
self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
|
76 |
-
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
77 |
-
|
78 |
-
def get_adversary_pred(self, x):
|
79 |
-
"""Run adversary model, validating expected output format."""
|
80 |
-
logits, fmaps = self.adversary(x)
|
81 |
-
assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
|
82 |
-
f'Expecting a list of tensors as logits but {type(logits)} found.'
|
83 |
-
assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
|
84 |
-
for fmap in fmaps:
|
85 |
-
assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
|
86 |
-
f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
|
87 |
-
return logits, fmaps
|
88 |
-
|
89 |
-
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
|
90 |
-
"""Train the adversary with the given fake and real example.
|
91 |
-
|
92 |
-
We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
|
93 |
-
The first item being the logits and second item being a list of feature maps for each sub-discriminator.
|
94 |
-
|
95 |
-
This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
|
96 |
-
and call the optimizer.
|
97 |
-
"""
|
98 |
-
loss = torch.tensor(0., device=fake.device)
|
99 |
-
all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
|
100 |
-
all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
|
101 |
-
n_sub_adversaries = len(all_logits_fake_is_fake)
|
102 |
-
for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
|
103 |
-
loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
|
104 |
-
|
105 |
-
if self.normalize:
|
106 |
-
loss /= n_sub_adversaries
|
107 |
-
|
108 |
-
self.optimizer.zero_grad()
|
109 |
-
with flashy.distrib.eager_sync_model(self.adversary):
|
110 |
-
loss.backward()
|
111 |
-
self.optimizer.step()
|
112 |
-
|
113 |
-
return loss
|
114 |
-
|
115 |
-
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
116 |
-
"""Return the loss for the generator, i.e. trying to fool the adversary,
|
117 |
-
and feature matching loss if provided.
|
118 |
-
"""
|
119 |
-
adv = torch.tensor(0., device=fake.device)
|
120 |
-
feat = torch.tensor(0., device=fake.device)
|
121 |
-
with flashy.utils.readonly(self.adversary):
|
122 |
-
all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
|
123 |
-
all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
|
124 |
-
n_sub_adversaries = len(all_logits_fake_is_fake)
|
125 |
-
for logit_fake_is_fake in all_logits_fake_is_fake:
|
126 |
-
adv += self.loss(logit_fake_is_fake)
|
127 |
-
if self.loss_feat:
|
128 |
-
for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
|
129 |
-
feat += self.loss_feat(fmap_fake, fmap_real)
|
130 |
-
|
131 |
-
if self.normalize:
|
132 |
-
adv /= n_sub_adversaries
|
133 |
-
feat /= n_sub_adversaries
|
134 |
-
|
135 |
-
return adv, feat
|
136 |
-
|
137 |
-
|
138 |
-
def get_adv_criterion(loss_type: str) -> tp.Callable:
|
139 |
-
assert loss_type in ADVERSARIAL_LOSSES
|
140 |
-
if loss_type == 'mse':
|
141 |
-
return mse_loss
|
142 |
-
elif loss_type == 'hinge':
|
143 |
-
return hinge_loss
|
144 |
-
elif loss_type == 'hinge2':
|
145 |
-
return hinge2_loss
|
146 |
-
raise ValueError('Unsupported loss')
|
147 |
-
|
148 |
-
|
149 |
-
def get_fake_criterion(loss_type: str) -> tp.Callable:
|
150 |
-
assert loss_type in ADVERSARIAL_LOSSES
|
151 |
-
if loss_type == 'mse':
|
152 |
-
return mse_fake_loss
|
153 |
-
elif loss_type in ['hinge', 'hinge2']:
|
154 |
-
return hinge_fake_loss
|
155 |
-
raise ValueError('Unsupported loss')
|
156 |
-
|
157 |
-
|
158 |
-
def get_real_criterion(loss_type: str) -> tp.Callable:
|
159 |
-
assert loss_type in ADVERSARIAL_LOSSES
|
160 |
-
if loss_type == 'mse':
|
161 |
-
return mse_real_loss
|
162 |
-
elif loss_type in ['hinge', 'hinge2']:
|
163 |
-
return hinge_real_loss
|
164 |
-
raise ValueError('Unsupported loss')
|
165 |
-
|
166 |
-
|
167 |
-
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
|
168 |
-
return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
|
169 |
-
|
170 |
-
|
171 |
-
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
|
172 |
-
return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
|
173 |
-
|
174 |
-
|
175 |
-
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
|
176 |
-
return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
177 |
-
|
178 |
-
|
179 |
-
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
|
180 |
-
return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
181 |
-
|
182 |
-
|
183 |
-
def mse_loss(x: torch.Tensor) -> torch.Tensor:
|
184 |
-
if x.numel() == 0:
|
185 |
-
return torch.tensor([0.0], device=x.device)
|
186 |
-
return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
|
187 |
-
|
188 |
-
|
189 |
-
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
|
190 |
-
if x.numel() == 0:
|
191 |
-
return torch.tensor([0.0], device=x.device)
|
192 |
-
return -x.mean()
|
193 |
-
|
194 |
-
|
195 |
-
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
|
196 |
-
if x.numel() == 0:
|
197 |
-
return torch.tensor([0.0])
|
198 |
-
return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
|
199 |
-
|
200 |
-
|
201 |
-
class FeatureMatchingLoss(nn.Module):
|
202 |
-
"""Feature matching loss for adversarial training.
|
203 |
-
|
204 |
-
Args:
|
205 |
-
loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
|
206 |
-
normalize (bool): Whether to normalize the loss.
|
207 |
-
by number of feature maps.
|
208 |
-
"""
|
209 |
-
def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
|
210 |
-
super().__init__()
|
211 |
-
self.loss = loss
|
212 |
-
self.normalize = normalize
|
213 |
-
|
214 |
-
def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
|
215 |
-
assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
|
216 |
-
feat_loss = torch.tensor(0., device=fmap_fake[0].device)
|
217 |
-
feat_scale = torch.tensor(0., device=fmap_fake[0].device)
|
218 |
-
n_fmaps = 0
|
219 |
-
for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
|
220 |
-
assert feat_fake.shape == feat_real.shape
|
221 |
-
n_fmaps += 1
|
222 |
-
feat_loss += self.loss(feat_fake, feat_real)
|
223 |
-
feat_scale += torch.mean(torch.abs(feat_real))
|
224 |
-
|
225 |
-
if self.normalize:
|
226 |
-
feat_loss /= n_fmaps
|
227 |
-
|
228 |
-
return feat_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/__init__.py
CHANGED
@@ -3,8 +3,6 @@
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Audio loading and writing support. Datasets for raw audio
|
7 |
-
or also including some metadata."""
|
8 |
|
9 |
# flake8: noqa
|
10 |
-
from . import audio, audio_dataset
|
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
6 |
|
7 |
# flake8: noqa
|
8 |
+
from . import audio, audio_dataset
|
audiocraft/data/audio.py
CHANGED
@@ -18,11 +18,11 @@ import numpy as np
|
|
18 |
import soundfile
|
19 |
import torch
|
20 |
from torch.nn import functional as F
|
|
|
21 |
|
22 |
import av
|
23 |
-
import subprocess as sp
|
24 |
|
25 |
-
from .audio_utils import f32_pcm, normalize_audio
|
26 |
|
27 |
|
28 |
_av_initialized = False
|
@@ -78,7 +78,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
|
|
78 |
seek_time (float): Time at which to start reading in the file.
|
79 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
80 |
Returns:
|
81 |
-
|
82 |
"""
|
83 |
_init_av()
|
84 |
with av.open(str(filepath)) as af:
|
@@ -123,7 +123,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
|
123 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
124 |
pad (bool): Pad output audio if not reaching expected duration.
|
125 |
Returns:
|
126 |
-
|
127 |
"""
|
128 |
fp = Path(filepath)
|
129 |
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
@@ -136,6 +136,12 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
|
136 |
wav = torch.from_numpy(wav).t().contiguous()
|
137 |
if len(wav.shape) == 1:
|
138 |
wav = torch.unsqueeze(wav, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
else:
|
140 |
wav, sr = _av_read(filepath, seek_time, duration)
|
141 |
if pad and duration > 0:
|
@@ -144,35 +150,19 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
|
144 |
return wav, sr
|
145 |
|
146 |
|
147 |
-
def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
|
148 |
-
# ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
|
149 |
-
assert wav.dim() == 2, wav.shape
|
150 |
-
command = [
|
151 |
-
'ffmpeg',
|
152 |
-
'-loglevel', 'error',
|
153 |
-
'-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
|
154 |
-
'-i', '-'] + flags + [str(out_path)]
|
155 |
-
input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
|
156 |
-
sp.run(command, input=input_, check=True)
|
157 |
-
|
158 |
-
|
159 |
def audio_write(stem_name: tp.Union[str, Path],
|
160 |
wav: torch.Tensor, sample_rate: int,
|
161 |
-
format: str = 'wav', mp3_rate: int = 320,
|
162 |
-
|
163 |
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
164 |
-
loudness_compressor: bool = False,
|
165 |
log_clipping: bool = True, make_parent_dir: bool = True,
|
166 |
add_suffix: bool = True) -> Path:
|
167 |
"""Convenience function for saving audio to disk. Returns the filename the audio was written to.
|
168 |
|
169 |
Args:
|
170 |
stem_name (str or Path): Filename without extension which will be added automatically.
|
171 |
-
|
172 |
-
sample_rate (int): Sample rate of audio data.
|
173 |
-
format (str): Either "wav", "mp3", "ogg", or "flac".
|
174 |
mp3_rate (int): kbps when using mp3s.
|
175 |
-
ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
|
176 |
normalize (bool): if `True` (default), normalizes according to the prescribed
|
177 |
strategy (see after). If `False`, the strategy is only used in case clipping
|
178 |
would happen.
|
@@ -183,8 +173,7 @@ def audio_write(stem_name: tp.Union[str, Path],
|
|
183 |
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
184 |
than the `peak_clip` one to avoid further clipping.
|
185 |
loudness_headroom_db (float): Target loudness for loudness normalization.
|
186 |
-
|
187 |
-
when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
|
188 |
occurs despite strategy (only for 'rms').
|
189 |
make_parent_dir (bool): Make parent directory if it doesn't exist.
|
190 |
Returns:
|
@@ -197,23 +186,16 @@ def audio_write(stem_name: tp.Union[str, Path],
|
|
197 |
raise ValueError("Input wav should be at most 2 dimension.")
|
198 |
assert wav.isfinite().all()
|
199 |
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
|
200 |
-
rms_headroom_db, loudness_headroom_db,
|
201 |
-
|
202 |
-
|
203 |
if format == 'mp3':
|
204 |
suffix = '.mp3'
|
205 |
-
|
206 |
elif format == 'wav':
|
|
|
207 |
suffix = '.wav'
|
208 |
-
|
209 |
-
elif format == 'ogg':
|
210 |
-
suffix = '.ogg'
|
211 |
-
flags = ['-f', 'ogg', '-c:a', 'libvorbis']
|
212 |
-
if ogg_rate is not None:
|
213 |
-
flags += ['-b:a', f'{ogg_rate}k']
|
214 |
-
elif format == 'flac':
|
215 |
-
suffix = '.flac'
|
216 |
-
flags = ['-f', 'flac']
|
217 |
else:
|
218 |
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
|
219 |
if not add_suffix:
|
@@ -222,7 +204,7 @@ def audio_write(stem_name: tp.Union[str, Path],
|
|
222 |
if make_parent_dir:
|
223 |
path.parent.mkdir(exist_ok=True, parents=True)
|
224 |
try:
|
225 |
-
|
226 |
except Exception:
|
227 |
if path.exists():
|
228 |
# we do not want to leave half written files around.
|
|
|
18 |
import soundfile
|
19 |
import torch
|
20 |
from torch.nn import functional as F
|
21 |
+
import torchaudio as ta
|
22 |
|
23 |
import av
|
|
|
24 |
|
25 |
+
from .audio_utils import f32_pcm, i16_pcm, normalize_audio
|
26 |
|
27 |
|
28 |
_av_initialized = False
|
|
|
78 |
seek_time (float): Time at which to start reading in the file.
|
79 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
80 |
Returns:
|
81 |
+
Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
|
82 |
"""
|
83 |
_init_av()
|
84 |
with av.open(str(filepath)) as af:
|
|
|
123 |
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
124 |
pad (bool): Pad output audio if not reaching expected duration.
|
125 |
Returns:
|
126 |
+
Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
|
127 |
"""
|
128 |
fp = Path(filepath)
|
129 |
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
|
|
136 |
wav = torch.from_numpy(wav).t().contiguous()
|
137 |
if len(wav.shape) == 1:
|
138 |
wav = torch.unsqueeze(wav, 0)
|
139 |
+
elif (
|
140 |
+
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
|
141 |
+
and duration <= 0 and seek_time == 0
|
142 |
+
):
|
143 |
+
# Torchaudio is faster if we load an entire file at once.
|
144 |
+
wav, sr = ta.load(fp)
|
145 |
else:
|
146 |
wav, sr = _av_read(filepath, seek_time, duration)
|
147 |
if pad and duration > 0:
|
|
|
150 |
return wav, sr
|
151 |
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
def audio_write(stem_name: tp.Union[str, Path],
|
154 |
wav: torch.Tensor, sample_rate: int,
|
155 |
+
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
|
156 |
+
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
157 |
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
|
|
158 |
log_clipping: bool = True, make_parent_dir: bool = True,
|
159 |
add_suffix: bool = True) -> Path:
|
160 |
"""Convenience function for saving audio to disk. Returns the filename the audio was written to.
|
161 |
|
162 |
Args:
|
163 |
stem_name (str or Path): Filename without extension which will be added automatically.
|
164 |
+
format (str): Either "wav" or "mp3".
|
|
|
|
|
165 |
mp3_rate (int): kbps when using mp3s.
|
|
|
166 |
normalize (bool): if `True` (default), normalizes according to the prescribed
|
167 |
strategy (see after). If `False`, the strategy is only used in case clipping
|
168 |
would happen.
|
|
|
173 |
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
174 |
than the `peak_clip` one to avoid further clipping.
|
175 |
loudness_headroom_db (float): Target loudness for loudness normalization.
|
176 |
+
log_clipping (bool): If True, basic logging on stderr when clipping still
|
|
|
177 |
occurs despite strategy (only for 'rms').
|
178 |
make_parent_dir (bool): Make parent directory if it doesn't exist.
|
179 |
Returns:
|
|
|
186 |
raise ValueError("Input wav should be at most 2 dimension.")
|
187 |
assert wav.isfinite().all()
|
188 |
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
|
189 |
+
rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
|
190 |
+
sample_rate=sample_rate, stem_name=str(stem_name))
|
191 |
+
kwargs: dict = {}
|
192 |
if format == 'mp3':
|
193 |
suffix = '.mp3'
|
194 |
+
kwargs.update({"compression": mp3_rate})
|
195 |
elif format == 'wav':
|
196 |
+
wav = i16_pcm(wav)
|
197 |
suffix = '.wav'
|
198 |
+
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
else:
|
200 |
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
|
201 |
if not add_suffix:
|
|
|
204 |
if make_parent_dir:
|
205 |
path.parent.mkdir(exist_ok=True, parents=True)
|
206 |
try:
|
207 |
+
ta.save(path, wav, sample_rate, **kwargs)
|
208 |
except Exception:
|
209 |
if path.exists():
|
210 |
# we do not want to leave half written files around.
|
audiocraft/data/audio_dataset.py
CHANGED
@@ -3,16 +3,12 @@
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
without having to scan again the folders, we precompute some metadata
|
8 |
-
(filename, sample rate, duration), and use that to efficiently sample audio segments.
|
9 |
-
"""
|
10 |
import argparse
|
11 |
import copy
|
12 |
from concurrent.futures import ThreadPoolExecutor, Future
|
13 |
from dataclasses import dataclass, fields
|
14 |
from contextlib import ExitStack
|
15 |
-
from functools import lru_cache
|
16 |
import gzip
|
17 |
import json
|
18 |
import logging
|
@@ -85,12 +81,9 @@ class AudioMeta(BaseInfo):
|
|
85 |
class SegmentInfo(BaseInfo):
|
86 |
meta: AudioMeta
|
87 |
seek_time: float
|
88 |
-
#
|
89 |
-
# at the target sample rate and target number of channels.
|
90 |
-
n_frames: int # actual number of frames without padding
|
91 |
total_frames: int # total number of frames, padding included
|
92 |
-
sample_rate: int
|
93 |
-
channels: int # number of audio channels.
|
94 |
|
95 |
|
96 |
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
|
@@ -121,8 +114,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
|
|
121 |
|
122 |
Args:
|
123 |
m (AudioMeta): Audio meta to resolve.
|
124 |
-
fast (bool): If True, uses a really fast check for determining if a file
|
125 |
-
|
126 |
Returns:
|
127 |
AudioMeta: Audio meta with resolved path.
|
128 |
"""
|
@@ -158,7 +151,7 @@ def find_audio_files(path: tp.Union[Path, str],
|
|
158 |
progress (bool): Whether to log progress on audio files collection.
|
159 |
workers (int): number of parallel workers, if 0, use only the current thread.
|
160 |
Returns:
|
161 |
-
|
162 |
"""
|
163 |
audio_files = []
|
164 |
futures: tp.List[Future] = []
|
@@ -210,7 +203,7 @@ def load_audio_meta(path: tp.Union[str, Path],
|
|
210 |
resolve (bool): Whether to resolve the path from AudioMeta (default=True).
|
211 |
fast (bool): activates some tricks to make things faster.
|
212 |
Returns:
|
213 |
-
|
214 |
"""
|
215 |
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
216 |
with open_fn(path, 'rb') as fp: # type: ignore
|
@@ -257,14 +250,9 @@ class AudioDataset:
|
|
257 |
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
258 |
original audio meta.
|
259 |
|
260 |
-
Note that you can call `start_epoch(epoch)` in order to get
|
261 |
-
a deterministic "randomization" for `shuffle=True`.
|
262 |
-
For a given epoch and dataset index, this will always return the same extract.
|
263 |
-
You can get back some diversity by setting the `shuffle_seed` param.
|
264 |
-
|
265 |
Args:
|
266 |
-
meta (
|
267 |
-
segment_duration (float
|
268 |
If not specified, the dataset will load the full audio segment from the file.
|
269 |
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
270 |
sample_rate (int): Target sample rate of the loaded audio samples.
|
@@ -278,19 +266,10 @@ class AudioDataset:
|
|
278 |
is shorter than the desired segment.
|
279 |
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
280 |
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
281 |
-
min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
|
282 |
audio shorter than this will be filtered out.
|
283 |
-
max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
|
284 |
audio longer than this will be filtered out.
|
285 |
-
shuffle_seed (int): can be used to further randomize
|
286 |
-
load_wav (bool): if False, skip loading the wav but returns a tensor of 0
|
287 |
-
with the expected segment_duration (which must be provided if load_wav is False).
|
288 |
-
permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
|
289 |
-
are False. Will ensure a permutation on files when going through the dataset.
|
290 |
-
In that case the epoch number must be provided in order for the model
|
291 |
-
to continue the permutation across epochs. In that case, it is assumed
|
292 |
-
that `num_samples = total_batch_size * num_updates_per_epoch`, with
|
293 |
-
`total_batch_size` the overall batch size accounting for all gpus.
|
294 |
"""
|
295 |
def __init__(self,
|
296 |
meta: tp.List[AudioMeta],
|
@@ -306,14 +285,16 @@ class AudioDataset:
|
|
306 |
max_read_retry: int = 10,
|
307 |
return_info: bool = False,
|
308 |
min_audio_duration: tp.Optional[float] = None,
|
309 |
-
max_audio_duration: tp.Optional[float] = None
|
310 |
-
shuffle_seed: int = 0,
|
311 |
-
load_wav: bool = True,
|
312 |
-
permutation_on_files: bool = False,
|
313 |
):
|
314 |
-
assert len(meta) > 0,
|
315 |
assert segment_duration is None or segment_duration > 0
|
316 |
assert segment_duration is None or min_segment_ratio >= 0
|
|
|
|
|
|
|
|
|
|
|
317 |
self.segment_duration = segment_duration
|
318 |
self.min_segment_ratio = min_segment_ratio
|
319 |
self.max_audio_duration = max_audio_duration
|
@@ -336,25 +317,13 @@ class AudioDataset:
|
|
336 |
self.sampling_probabilities = self._get_sampling_probabilities()
|
337 |
self.max_read_retry = max_read_retry
|
338 |
self.return_info = return_info
|
339 |
-
self.shuffle_seed = shuffle_seed
|
340 |
-
self.current_epoch: tp.Optional[int] = None
|
341 |
-
self.load_wav = load_wav
|
342 |
-
if not load_wav:
|
343 |
-
assert segment_duration is not None
|
344 |
-
self.permutation_on_files = permutation_on_files
|
345 |
-
if permutation_on_files:
|
346 |
-
assert not self.sample_on_duration
|
347 |
-
assert not self.sample_on_weight
|
348 |
-
assert self.shuffle
|
349 |
-
|
350 |
-
def start_epoch(self, epoch: int):
|
351 |
-
self.current_epoch = epoch
|
352 |
|
353 |
def __len__(self):
|
354 |
return self.num_samples
|
355 |
|
356 |
def _get_sampling_probabilities(self, normalized: bool = True):
|
357 |
-
"""Return the sampling probabilities for each file inside `self.meta`.
|
|
|
358 |
scores: tp.List[float] = []
|
359 |
for file_meta in self.meta:
|
360 |
score = 1.
|
@@ -368,32 +337,12 @@ class AudioDataset:
|
|
368 |
probabilities /= probabilities.sum()
|
369 |
return probabilities
|
370 |
|
371 |
-
|
372 |
-
|
373 |
-
def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
|
374 |
-
# Used to keep the most recent files permutation in memory implicitely.
|
375 |
-
# will work unless someone is using a lot of Datasets in parallel.
|
376 |
-
rng = torch.Generator()
|
377 |
-
rng.manual_seed(base_seed + permutation_index)
|
378 |
-
return torch.randperm(num_files, generator=rng)
|
379 |
-
|
380 |
-
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
|
381 |
-
"""Sample a given file from `self.meta`. Can be overridden in subclasses.
|
382 |
This is only called if `segment_duration` is not None.
|
383 |
|
384 |
You must use the provided random number generator `rng` for reproducibility.
|
385 |
-
You can further make use of the index accessed.
|
386 |
"""
|
387 |
-
if self.permutation_on_files:
|
388 |
-
assert self.current_epoch is not None
|
389 |
-
total_index = self.current_epoch * len(self) + index
|
390 |
-
permutation_index = total_index // len(self.meta)
|
391 |
-
relative_index = total_index % len(self.meta)
|
392 |
-
permutation = AudioDataset._get_file_permutation(
|
393 |
-
len(self.meta), permutation_index, self.shuffle_seed)
|
394 |
-
file_index = permutation[relative_index]
|
395 |
-
return self.meta[file_index]
|
396 |
-
|
397 |
if not self.sample_on_weight and not self.sample_on_duration:
|
398 |
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
399 |
else:
|
@@ -401,15 +350,6 @@ class AudioDataset:
|
|
401 |
|
402 |
return self.meta[file_index]
|
403 |
|
404 |
-
def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
|
405 |
-
# Override this method in subclass if needed.
|
406 |
-
if self.load_wav:
|
407 |
-
return audio_read(path, seek_time, duration, pad=False)
|
408 |
-
else:
|
409 |
-
assert self.segment_duration is not None
|
410 |
-
n_frames = int(self.sample_rate * self.segment_duration)
|
411 |
-
return torch.zeros(self.channels, n_frames), self.sample_rate
|
412 |
-
|
413 |
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
414 |
if self.segment_duration is None:
|
415 |
file_meta = self.meta[index]
|
@@ -417,22 +357,18 @@ class AudioDataset:
|
|
417 |
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
418 |
n_frames = out.shape[-1]
|
419 |
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
420 |
-
sample_rate=self.sample_rate
|
421 |
else:
|
422 |
rng = torch.Generator()
|
423 |
if self.shuffle:
|
424 |
-
# We use index, plus extra randomness
|
425 |
-
|
426 |
-
if self.current_epoch is None:
|
427 |
-
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
428 |
-
else:
|
429 |
-
rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
|
430 |
else:
|
431 |
# We only use index
|
432 |
rng.manual_seed(index)
|
433 |
|
434 |
for retry in range(self.max_read_retry):
|
435 |
-
file_meta = self.sample_file(
|
436 |
# We add some variance in the file position even if audio file is smaller than segment
|
437 |
# without ending up with empty segments
|
438 |
max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
|
@@ -445,7 +381,7 @@ class AudioDataset:
|
|
445 |
if self.pad:
|
446 |
out = F.pad(out, (0, target_frames - n_frames))
|
447 |
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
448 |
-
sample_rate=self.sample_rate
|
449 |
except Exception as exc:
|
450 |
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
451 |
if retry == self.max_read_retry - 1:
|
@@ -487,7 +423,7 @@ class AudioDataset:
|
|
487 |
if to_pad:
|
488 |
# Each wav could be of a different duration as they are not segmented.
|
489 |
for i in range(len(samples)):
|
490 |
-
# Determines the total
|
491 |
segment_infos[i].total_frames = max_len
|
492 |
wavs[i] = _pad_wav(wavs[i])
|
493 |
|
@@ -500,7 +436,9 @@ class AudioDataset:
|
|
500 |
return torch.stack(samples)
|
501 |
|
502 |
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
503 |
-
"""Filters out audio files with
|
|
|
|
|
504 |
orig_len = len(meta)
|
505 |
|
506 |
# Filter data that is too short.
|
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
|
|
|
|
|
|
7 |
import argparse
|
8 |
import copy
|
9 |
from concurrent.futures import ThreadPoolExecutor, Future
|
10 |
from dataclasses import dataclass, fields
|
11 |
from contextlib import ExitStack
|
|
|
12 |
import gzip
|
13 |
import json
|
14 |
import logging
|
|
|
81 |
class SegmentInfo(BaseInfo):
|
82 |
meta: AudioMeta
|
83 |
seek_time: float
|
84 |
+
n_frames: int # actual number of frames without padding
|
|
|
|
|
85 |
total_frames: int # total number of frames, padding included
|
86 |
+
sample_rate: int # actual sample rate
|
|
|
87 |
|
88 |
|
89 |
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
|
|
|
114 |
|
115 |
Args:
|
116 |
m (AudioMeta): Audio meta to resolve.
|
117 |
+
fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
|
118 |
+
Only valid on Linux/Mac.
|
119 |
Returns:
|
120 |
AudioMeta: Audio meta with resolved path.
|
121 |
"""
|
|
|
151 |
progress (bool): Whether to log progress on audio files collection.
|
152 |
workers (int): number of parallel workers, if 0, use only the current thread.
|
153 |
Returns:
|
154 |
+
List[AudioMeta]: List of audio file path and its metadata.
|
155 |
"""
|
156 |
audio_files = []
|
157 |
futures: tp.List[Future] = []
|
|
|
203 |
resolve (bool): Whether to resolve the path from AudioMeta (default=True).
|
204 |
fast (bool): activates some tricks to make things faster.
|
205 |
Returns:
|
206 |
+
List[AudioMeta]: List of audio file path and its total duration.
|
207 |
"""
|
208 |
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
209 |
with open_fn(path, 'rb') as fp: # type: ignore
|
|
|
250 |
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
251 |
original audio meta.
|
252 |
|
|
|
|
|
|
|
|
|
|
|
253 |
Args:
|
254 |
+
meta (tp.List[AudioMeta]): List of audio files metadata.
|
255 |
+
segment_duration (float): Optional segment duration of audio to load.
|
256 |
If not specified, the dataset will load the full audio segment from the file.
|
257 |
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
258 |
sample_rate (int): Target sample rate of the loaded audio samples.
|
|
|
266 |
is shorter than the desired segment.
|
267 |
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
268 |
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
269 |
+
min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
|
270 |
audio shorter than this will be filtered out.
|
271 |
+
max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
|
272 |
audio longer than this will be filtered out.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
"""
|
274 |
def __init__(self,
|
275 |
meta: tp.List[AudioMeta],
|
|
|
285 |
max_read_retry: int = 10,
|
286 |
return_info: bool = False,
|
287 |
min_audio_duration: tp.Optional[float] = None,
|
288 |
+
max_audio_duration: tp.Optional[float] = None
|
|
|
|
|
|
|
289 |
):
|
290 |
+
assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
|
291 |
assert segment_duration is None or segment_duration > 0
|
292 |
assert segment_duration is None or min_segment_ratio >= 0
|
293 |
+
logging.debug(f'sample_on_duration: {sample_on_duration}')
|
294 |
+
logging.debug(f'sample_on_weight: {sample_on_weight}')
|
295 |
+
logging.debug(f'pad: {pad}')
|
296 |
+
logging.debug(f'min_segment_ratio: {min_segment_ratio}')
|
297 |
+
|
298 |
self.segment_duration = segment_duration
|
299 |
self.min_segment_ratio = min_segment_ratio
|
300 |
self.max_audio_duration = max_audio_duration
|
|
|
317 |
self.sampling_probabilities = self._get_sampling_probabilities()
|
318 |
self.max_read_retry = max_read_retry
|
319 |
self.return_info = return_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
def __len__(self):
|
322 |
return self.num_samples
|
323 |
|
324 |
def _get_sampling_probabilities(self, normalized: bool = True):
|
325 |
+
"""Return the sampling probabilities for each file inside `self.meta`.
|
326 |
+
"""
|
327 |
scores: tp.List[float] = []
|
328 |
for file_meta in self.meta:
|
329 |
score = 1.
|
|
|
337 |
probabilities /= probabilities.sum()
|
338 |
return probabilities
|
339 |
|
340 |
+
def sample_file(self, rng: torch.Generator) -> AudioMeta:
|
341 |
+
"""Sample a given file from `self.meta`. Can be overriden in subclasses.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
This is only called if `segment_duration` is not None.
|
343 |
|
344 |
You must use the provided random number generator `rng` for reproducibility.
|
|
|
345 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
if not self.sample_on_weight and not self.sample_on_duration:
|
347 |
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
348 |
else:
|
|
|
350 |
|
351 |
return self.meta[file_index]
|
352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
354 |
if self.segment_duration is None:
|
355 |
file_meta = self.meta[index]
|
|
|
357 |
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
358 |
n_frames = out.shape[-1]
|
359 |
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
360 |
+
sample_rate=self.sample_rate)
|
361 |
else:
|
362 |
rng = torch.Generator()
|
363 |
if self.shuffle:
|
364 |
+
# We use index, plus extra randomness
|
365 |
+
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
|
|
|
|
|
|
|
|
366 |
else:
|
367 |
# We only use index
|
368 |
rng.manual_seed(index)
|
369 |
|
370 |
for retry in range(self.max_read_retry):
|
371 |
+
file_meta = self.sample_file(rng)
|
372 |
# We add some variance in the file position even if audio file is smaller than segment
|
373 |
# without ending up with empty segments
|
374 |
max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
|
|
|
381 |
if self.pad:
|
382 |
out = F.pad(out, (0, target_frames - n_frames))
|
383 |
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
384 |
+
sample_rate=self.sample_rate)
|
385 |
except Exception as exc:
|
386 |
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
387 |
if retry == self.max_read_retry - 1:
|
|
|
423 |
if to_pad:
|
424 |
# Each wav could be of a different duration as they are not segmented.
|
425 |
for i in range(len(samples)):
|
426 |
+
# Determines the total legth of the signal with padding, so we update here as we pad.
|
427 |
segment_infos[i].total_frames = max_len
|
428 |
wavs[i] = _pad_wav(wavs[i])
|
429 |
|
|
|
436 |
return torch.stack(samples)
|
437 |
|
438 |
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
439 |
+
"""Filters out audio files with short durations.
|
440 |
+
Removes from meta files that have durations that will not allow to samples examples from them.
|
441 |
+
"""
|
442 |
orig_len = len(meta)
|
443 |
|
444 |
# Filter data that is too short.
|
audiocraft/data/audio_utils.py
CHANGED
@@ -3,8 +3,7 @@
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
and volume normalization."""
|
8 |
import sys
|
9 |
import typing as tp
|
10 |
|
@@ -48,14 +47,15 @@ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor
|
|
48 |
|
49 |
def convert_audio(wav: torch.Tensor, from_rate: float,
|
50 |
to_rate: float, to_channels: int) -> torch.Tensor:
|
51 |
-
"""Convert audio to new sample rate and number of audio channels.
|
|
|
52 |
wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
|
53 |
wav = convert_audio_channels(wav, to_channels)
|
54 |
return wav
|
55 |
|
56 |
|
57 |
-
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float =
|
58 |
-
|
59 |
"""Normalize an input signal to a user loudness in dB LKFS.
|
60 |
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
|
61 |
|
@@ -63,10 +63,9 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db
|
|
63 |
wav (torch.Tensor): Input multichannel audio data.
|
64 |
sample_rate (int): Sample rate.
|
65 |
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
|
66 |
-
loudness_compressor (bool): Uses tanh for soft clipping.
|
67 |
energy_floor (float): anything below that RMS level will not be rescaled.
|
68 |
Returns:
|
69 |
-
torch.Tensor: Loudness normalized output data.
|
70 |
"""
|
71 |
energy = wav.pow(2).mean().sqrt().item()
|
72 |
if energy < energy_floor:
|
@@ -77,8 +76,6 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db
|
|
77 |
delta_loudness = -loudness_headroom_db - input_loudness_db
|
78 |
gain = 10.0 ** (delta_loudness / 20.0)
|
79 |
output = gain * wav
|
80 |
-
if loudness_compressor:
|
81 |
-
output = torch.tanh(output)
|
82 |
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
|
83 |
return output
|
84 |
|
@@ -96,8 +93,7 @@ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optio
|
|
96 |
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
97 |
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
98 |
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
99 |
-
|
100 |
-
sample_rate: tp.Optional[int] = None,
|
101 |
stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
102 |
"""Normalize the audio according to the prescribed strategy (see after).
|
103 |
|
@@ -113,11 +109,10 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
|
113 |
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
114 |
than the `peak_clip` one to avoid further clipping.
|
115 |
loudness_headroom_db (float): Target loudness for loudness normalization.
|
116 |
-
loudness_compressor (bool): If True, uses tanh based soft clipping.
|
117 |
log_clipping (bool): If True, basic logging on stderr when clipping still
|
118 |
occurs despite strategy (only for 'rms').
|
119 |
sample_rate (int): Sample rate for the audio data (required for loudness).
|
120 |
-
stem_name (str
|
121 |
Returns:
|
122 |
torch.Tensor: Normalized audio.
|
123 |
"""
|
@@ -137,7 +132,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
|
137 |
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
138 |
elif strategy == 'loudness':
|
139 |
assert sample_rate is not None, "Loudness normalization requires sample rate."
|
140 |
-
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db
|
141 |
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
142 |
else:
|
143 |
assert wav.abs().max() < 1
|
@@ -150,19 +145,17 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
|
150 |
"""
|
151 |
if wav.dtype.is_floating_point:
|
152 |
return wav
|
153 |
-
|
|
|
154 |
return wav.float() / 2**15
|
155 |
-
elif wav.dtype == torch.int32:
|
156 |
-
return wav.float() / 2**31
|
157 |
-
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
158 |
|
159 |
|
160 |
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
|
161 |
"""Convert audio to int 16 bits PCM format.
|
162 |
|
163 |
-
..Warning:: There exist many formula for doing this
|
164 |
-
due to the
|
165 |
-
or
|
166 |
it is possible that `i16_pcm(f32_pcm)) != Identity`.
|
167 |
"""
|
168 |
if wav.dtype.is_floating_point:
|
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
|
|
7 |
import sys
|
8 |
import typing as tp
|
9 |
|
|
|
47 |
|
48 |
def convert_audio(wav: torch.Tensor, from_rate: float,
|
49 |
to_rate: float, to_channels: int) -> torch.Tensor:
|
50 |
+
"""Convert audio to new sample rate and number of audio channels.
|
51 |
+
"""
|
52 |
wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
|
53 |
wav = convert_audio_channels(wav, to_channels)
|
54 |
return wav
|
55 |
|
56 |
|
57 |
+
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 12,
|
58 |
+
energy_floor: float = 2e-3):
|
59 |
"""Normalize an input signal to a user loudness in dB LKFS.
|
60 |
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
|
61 |
|
|
|
63 |
wav (torch.Tensor): Input multichannel audio data.
|
64 |
sample_rate (int): Sample rate.
|
65 |
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
|
|
|
66 |
energy_floor (float): anything below that RMS level will not be rescaled.
|
67 |
Returns:
|
68 |
+
output (torch.Tensor): Loudness normalized output data.
|
69 |
"""
|
70 |
energy = wav.pow(2).mean().sqrt().item()
|
71 |
if energy < energy_floor:
|
|
|
76 |
delta_loudness = -loudness_headroom_db - input_loudness_db
|
77 |
gain = 10.0 ** (delta_loudness / 20.0)
|
78 |
output = gain * wav
|
|
|
|
|
79 |
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
|
80 |
return output
|
81 |
|
|
|
93 |
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
94 |
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
95 |
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
96 |
+
log_clipping: bool = False, sample_rate: tp.Optional[int] = None,
|
|
|
97 |
stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
98 |
"""Normalize the audio according to the prescribed strategy (see after).
|
99 |
|
|
|
109 |
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
110 |
than the `peak_clip` one to avoid further clipping.
|
111 |
loudness_headroom_db (float): Target loudness for loudness normalization.
|
|
|
112 |
log_clipping (bool): If True, basic logging on stderr when clipping still
|
113 |
occurs despite strategy (only for 'rms').
|
114 |
sample_rate (int): Sample rate for the audio data (required for loudness).
|
115 |
+
stem_name (Optional[str]): Stem name for clipping logging.
|
116 |
Returns:
|
117 |
torch.Tensor: Normalized audio.
|
118 |
"""
|
|
|
132 |
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
133 |
elif strategy == 'loudness':
|
134 |
assert sample_rate is not None, "Loudness normalization requires sample rate."
|
135 |
+
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db)
|
136 |
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
137 |
else:
|
138 |
assert wav.abs().max() < 1
|
|
|
145 |
"""
|
146 |
if wav.dtype.is_floating_point:
|
147 |
return wav
|
148 |
+
else:
|
149 |
+
assert wav.dtype == torch.int16
|
150 |
return wav.float() / 2**15
|
|
|
|
|
|
|
151 |
|
152 |
|
153 |
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
|
154 |
"""Convert audio to int 16 bits PCM format.
|
155 |
|
156 |
+
..Warning:: There exist many formula for doing this convertion. None are perfect
|
157 |
+
due to the asymetry of the int16 range. One either have possible clipping, DC offset,
|
158 |
+
or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom,
|
159 |
it is possible that `i16_pcm(f32_pcm)) != Identity`.
|
160 |
"""
|
161 |
if wav.dtype.is_floating_point:
|
audiocraft/data/info_audio_dataset.py
DELETED
@@ -1,110 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Base classes for the datasets that also provide non-audio metadata,
|
7 |
-
e.g. description, text transcription etc.
|
8 |
-
"""
|
9 |
-
from dataclasses import dataclass
|
10 |
-
import logging
|
11 |
-
import math
|
12 |
-
import re
|
13 |
-
import typing as tp
|
14 |
-
|
15 |
-
import torch
|
16 |
-
|
17 |
-
from .audio_dataset import AudioDataset, AudioMeta
|
18 |
-
from ..environment import AudioCraftEnvironment
|
19 |
-
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
|
20 |
-
|
21 |
-
|
22 |
-
logger = logging.getLogger(__name__)
|
23 |
-
|
24 |
-
|
25 |
-
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
|
26 |
-
"""Monkey-patch meta to match cluster specificities."""
|
27 |
-
meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
|
28 |
-
if meta.info_path is not None:
|
29 |
-
meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
|
30 |
-
return meta
|
31 |
-
|
32 |
-
|
33 |
-
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
34 |
-
"""Monkey-patch all meta to match cluster specificities."""
|
35 |
-
return [_clusterify_meta(m) for m in meta]
|
36 |
-
|
37 |
-
|
38 |
-
@dataclass
|
39 |
-
class AudioInfo(SegmentWithAttributes):
|
40 |
-
"""Dummy SegmentInfo with empty attributes.
|
41 |
-
|
42 |
-
The InfoAudioDataset is expected to return metadata that inherits
|
43 |
-
from SegmentWithAttributes class and can return conditioning attributes.
|
44 |
-
|
45 |
-
This basically guarantees all datasets will be compatible with current
|
46 |
-
solver that contain conditioners requiring this.
|
47 |
-
"""
|
48 |
-
audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
|
49 |
-
|
50 |
-
def to_condition_attributes(self) -> ConditioningAttributes:
|
51 |
-
return ConditioningAttributes()
|
52 |
-
|
53 |
-
|
54 |
-
class InfoAudioDataset(AudioDataset):
|
55 |
-
"""AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
|
56 |
-
|
57 |
-
See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
|
58 |
-
"""
|
59 |
-
def __init__(self, meta: tp.List[AudioMeta], **kwargs):
|
60 |
-
super().__init__(clusterify_all_meta(meta), **kwargs)
|
61 |
-
|
62 |
-
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
|
63 |
-
if not self.return_info:
|
64 |
-
wav = super().__getitem__(index)
|
65 |
-
assert isinstance(wav, torch.Tensor)
|
66 |
-
return wav
|
67 |
-
wav, meta = super().__getitem__(index)
|
68 |
-
return wav, AudioInfo(**meta.to_dict())
|
69 |
-
|
70 |
-
|
71 |
-
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
|
72 |
-
"""Preprocess a single keyword or possible a list of keywords."""
|
73 |
-
if isinstance(value, list):
|
74 |
-
return get_keyword_list(value)
|
75 |
-
else:
|
76 |
-
return get_keyword(value)
|
77 |
-
|
78 |
-
|
79 |
-
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
|
80 |
-
"""Preprocess a single keyword."""
|
81 |
-
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
82 |
-
return None
|
83 |
-
else:
|
84 |
-
return value.strip()
|
85 |
-
|
86 |
-
|
87 |
-
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
|
88 |
-
"""Preprocess a single keyword."""
|
89 |
-
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
90 |
-
return None
|
91 |
-
else:
|
92 |
-
return value.strip().lower()
|
93 |
-
|
94 |
-
|
95 |
-
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
|
96 |
-
"""Preprocess a list of keywords."""
|
97 |
-
if isinstance(values, str):
|
98 |
-
values = [v.strip() for v in re.split(r'[,\s]', values)]
|
99 |
-
elif isinstance(values, float) and math.isnan(values):
|
100 |
-
values = []
|
101 |
-
if not isinstance(values, list):
|
102 |
-
logger.debug(f"Unexpected keyword list {values}")
|
103 |
-
values = [str(values)]
|
104 |
-
|
105 |
-
kws = [get_keyword(v) for v in values]
|
106 |
-
kw_list = [k for k in kws if k is not None]
|
107 |
-
if len(kw_list) == 0:
|
108 |
-
return None
|
109 |
-
else:
|
110 |
-
return kw_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/music_dataset.py
DELETED
@@ -1,270 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Dataset of music tracks with rich metadata.
|
7 |
-
"""
|
8 |
-
from dataclasses import dataclass, field, fields, replace
|
9 |
-
import gzip
|
10 |
-
import json
|
11 |
-
import logging
|
12 |
-
from pathlib import Path
|
13 |
-
import random
|
14 |
-
import typing as tp
|
15 |
-
|
16 |
-
import torch
|
17 |
-
|
18 |
-
from .info_audio_dataset import (
|
19 |
-
InfoAudioDataset,
|
20 |
-
AudioInfo,
|
21 |
-
get_keyword_list,
|
22 |
-
get_keyword,
|
23 |
-
get_string
|
24 |
-
)
|
25 |
-
from ..modules.conditioners import (
|
26 |
-
ConditioningAttributes,
|
27 |
-
JointEmbedCondition,
|
28 |
-
WavCondition,
|
29 |
-
)
|
30 |
-
from ..utils.utils import warn_once
|
31 |
-
|
32 |
-
|
33 |
-
logger = logging.getLogger(__name__)
|
34 |
-
|
35 |
-
|
36 |
-
@dataclass
|
37 |
-
class MusicInfo(AudioInfo):
|
38 |
-
"""Segment info augmented with music metadata.
|
39 |
-
"""
|
40 |
-
# music-specific metadata
|
41 |
-
title: tp.Optional[str] = None
|
42 |
-
artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
|
43 |
-
key: tp.Optional[str] = None
|
44 |
-
bpm: tp.Optional[float] = None
|
45 |
-
genre: tp.Optional[str] = None
|
46 |
-
moods: tp.Optional[list] = None
|
47 |
-
keywords: tp.Optional[list] = None
|
48 |
-
description: tp.Optional[str] = None
|
49 |
-
name: tp.Optional[str] = None
|
50 |
-
instrument: tp.Optional[str] = None
|
51 |
-
# original wav accompanying the metadata
|
52 |
-
self_wav: tp.Optional[WavCondition] = None
|
53 |
-
# dict mapping attributes names to tuple of wav, text and metadata
|
54 |
-
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
55 |
-
|
56 |
-
@property
|
57 |
-
def has_music_meta(self) -> bool:
|
58 |
-
return self.name is not None
|
59 |
-
|
60 |
-
def to_condition_attributes(self) -> ConditioningAttributes:
|
61 |
-
out = ConditioningAttributes()
|
62 |
-
for _field in fields(self):
|
63 |
-
key, value = _field.name, getattr(self, _field.name)
|
64 |
-
if key == 'self_wav':
|
65 |
-
out.wav[key] = value
|
66 |
-
elif key == 'joint_embed':
|
67 |
-
for embed_attribute, embed_cond in value.items():
|
68 |
-
out.joint_embed[embed_attribute] = embed_cond
|
69 |
-
else:
|
70 |
-
if isinstance(value, list):
|
71 |
-
value = ' '.join(value)
|
72 |
-
out.text[key] = value
|
73 |
-
return out
|
74 |
-
|
75 |
-
@staticmethod
|
76 |
-
def attribute_getter(attribute):
|
77 |
-
if attribute == 'bpm':
|
78 |
-
preprocess_func = get_bpm
|
79 |
-
elif attribute == 'key':
|
80 |
-
preprocess_func = get_musical_key
|
81 |
-
elif attribute in ['moods', 'keywords']:
|
82 |
-
preprocess_func = get_keyword_list
|
83 |
-
elif attribute in ['genre', 'name', 'instrument']:
|
84 |
-
preprocess_func = get_keyword
|
85 |
-
elif attribute in ['title', 'artist', 'description']:
|
86 |
-
preprocess_func = get_string
|
87 |
-
else:
|
88 |
-
preprocess_func = None
|
89 |
-
return preprocess_func
|
90 |
-
|
91 |
-
@classmethod
|
92 |
-
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
93 |
-
_dictionary: tp.Dict[str, tp.Any] = {}
|
94 |
-
|
95 |
-
# allow a subset of attributes to not be loaded from the dictionary
|
96 |
-
# these attributes may be populated later
|
97 |
-
post_init_attributes = ['self_wav', 'joint_embed']
|
98 |
-
optional_fields = ['keywords']
|
99 |
-
|
100 |
-
for _field in fields(cls):
|
101 |
-
if _field.name in post_init_attributes:
|
102 |
-
continue
|
103 |
-
elif _field.name not in dictionary:
|
104 |
-
if fields_required and _field.name not in optional_fields:
|
105 |
-
raise KeyError(f"Unexpected missing key: {_field.name}")
|
106 |
-
else:
|
107 |
-
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
108 |
-
value = dictionary[_field.name]
|
109 |
-
if preprocess_func:
|
110 |
-
value = preprocess_func(value)
|
111 |
-
_dictionary[_field.name] = value
|
112 |
-
return cls(**_dictionary)
|
113 |
-
|
114 |
-
|
115 |
-
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
|
116 |
-
drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
|
117 |
-
"""Augment MusicInfo description with additional metadata fields and potential dropout.
|
118 |
-
Additional textual attributes are added given probability 'merge_text_conditions_p' and
|
119 |
-
the original textual description is dropped from the augmented description given probability drop_desc_p.
|
120 |
-
|
121 |
-
Args:
|
122 |
-
music_info (MusicInfo): The music metadata to augment.
|
123 |
-
merge_text_p (float): Probability of merging additional metadata to the description.
|
124 |
-
If provided value is 0, then no merging is performed.
|
125 |
-
drop_desc_p (float): Probability of dropping the original description on text merge.
|
126 |
-
if provided value is 0, then no drop out is performed.
|
127 |
-
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
128 |
-
Returns:
|
129 |
-
MusicInfo: The MusicInfo with augmented textual description.
|
130 |
-
"""
|
131 |
-
def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
|
132 |
-
valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
|
133 |
-
valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
|
134 |
-
keep_field = random.uniform(0, 1) < drop_other_p
|
135 |
-
return valid_field_name and valid_field_value and keep_field
|
136 |
-
|
137 |
-
def process_value(v: tp.Any) -> str:
|
138 |
-
if isinstance(v, (int, float, str)):
|
139 |
-
return str(v)
|
140 |
-
if isinstance(v, list):
|
141 |
-
return ", ".join(v)
|
142 |
-
else:
|
143 |
-
raise ValueError(f"Unknown type for text value! ({type(v), v})")
|
144 |
-
|
145 |
-
description = music_info.description
|
146 |
-
|
147 |
-
metadata_text = ""
|
148 |
-
if random.uniform(0, 1) < merge_text_p:
|
149 |
-
meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
|
150 |
-
for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
|
151 |
-
random.shuffle(meta_pairs)
|
152 |
-
metadata_text = ". ".join(meta_pairs)
|
153 |
-
description = description if not random.uniform(0, 1) < drop_desc_p else None
|
154 |
-
logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
|
155 |
-
|
156 |
-
if description is None:
|
157 |
-
description = metadata_text if len(metadata_text) > 1 else None
|
158 |
-
else:
|
159 |
-
description = ". ".join([description.rstrip('.'), metadata_text])
|
160 |
-
description = description.strip() if description else None
|
161 |
-
|
162 |
-
music_info = replace(music_info)
|
163 |
-
music_info.description = description
|
164 |
-
return music_info
|
165 |
-
|
166 |
-
|
167 |
-
class Paraphraser:
|
168 |
-
def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
|
169 |
-
self.paraphrase_p = paraphrase_p
|
170 |
-
open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
|
171 |
-
with open_fn(paraphrase_source, 'rb') as f: # type: ignore
|
172 |
-
self.paraphrase_source = json.loads(f.read())
|
173 |
-
logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
|
174 |
-
|
175 |
-
def sample_paraphrase(self, audio_path: str, description: str):
|
176 |
-
if random.random() >= self.paraphrase_p:
|
177 |
-
return description
|
178 |
-
info_path = Path(audio_path).with_suffix('.json')
|
179 |
-
if info_path not in self.paraphrase_source:
|
180 |
-
warn_once(logger, f"{info_path} not in paraphrase source!")
|
181 |
-
return description
|
182 |
-
new_desc = random.choice(self.paraphrase_source[info_path])
|
183 |
-
logger.debug(f"{description} -> {new_desc}")
|
184 |
-
return new_desc
|
185 |
-
|
186 |
-
|
187 |
-
class MusicDataset(InfoAudioDataset):
|
188 |
-
"""Music dataset is an AudioDataset with music-related metadata.
|
189 |
-
|
190 |
-
Args:
|
191 |
-
info_fields_required (bool): Whether to enforce having required fields.
|
192 |
-
merge_text_p (float): Probability of merging additional metadata to the description.
|
193 |
-
drop_desc_p (float): Probability of dropping the original description on text merge.
|
194 |
-
drop_other_p (float): Probability of dropping the other fields used for text augmentation.
|
195 |
-
joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
|
196 |
-
paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
|
197 |
-
paraphrases for the description. The json should be a dict with keys are the
|
198 |
-
original info path (e.g. track_path.json) and each value is a list of possible
|
199 |
-
paraphrased.
|
200 |
-
paraphrase_p (float): probability of taking a paraphrase.
|
201 |
-
|
202 |
-
See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
|
203 |
-
"""
|
204 |
-
def __init__(self, *args, info_fields_required: bool = True,
|
205 |
-
merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
|
206 |
-
joint_embed_attributes: tp.List[str] = [],
|
207 |
-
paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
|
208 |
-
**kwargs):
|
209 |
-
kwargs['return_info'] = True # We require the info for each song of the dataset.
|
210 |
-
super().__init__(*args, **kwargs)
|
211 |
-
self.info_fields_required = info_fields_required
|
212 |
-
self.merge_text_p = merge_text_p
|
213 |
-
self.drop_desc_p = drop_desc_p
|
214 |
-
self.drop_other_p = drop_other_p
|
215 |
-
self.joint_embed_attributes = joint_embed_attributes
|
216 |
-
self.paraphraser = None
|
217 |
-
if paraphrase_source is not None:
|
218 |
-
self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
|
219 |
-
|
220 |
-
def __getitem__(self, index):
|
221 |
-
wav, info = super().__getitem__(index)
|
222 |
-
info_data = info.to_dict()
|
223 |
-
music_info_path = Path(info.meta.path).with_suffix('.json')
|
224 |
-
|
225 |
-
if Path(music_info_path).exists():
|
226 |
-
with open(music_info_path, 'r') as json_file:
|
227 |
-
music_data = json.load(json_file)
|
228 |
-
music_data.update(info_data)
|
229 |
-
music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
|
230 |
-
if self.paraphraser is not None:
|
231 |
-
music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
|
232 |
-
if self.merge_text_p:
|
233 |
-
music_info = augment_music_info_description(
|
234 |
-
music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
|
235 |
-
else:
|
236 |
-
music_info = MusicInfo.from_dict(info_data, fields_required=False)
|
237 |
-
|
238 |
-
music_info.self_wav = WavCondition(
|
239 |
-
wav=wav[None], length=torch.tensor([info.n_frames]),
|
240 |
-
sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
241 |
-
|
242 |
-
for att in self.joint_embed_attributes:
|
243 |
-
att_value = getattr(music_info, att)
|
244 |
-
joint_embed_cond = JointEmbedCondition(
|
245 |
-
wav[None], [att_value], torch.tensor([info.n_frames]),
|
246 |
-
sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
247 |
-
music_info.joint_embed[att] = joint_embed_cond
|
248 |
-
|
249 |
-
return wav, music_info
|
250 |
-
|
251 |
-
|
252 |
-
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
|
253 |
-
"""Preprocess key keywords, discarding them if there are multiple key defined."""
|
254 |
-
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
255 |
-
return None
|
256 |
-
elif ',' in value:
|
257 |
-
# For now, we discard when multiple keys are defined separated with comas
|
258 |
-
return None
|
259 |
-
else:
|
260 |
-
return value.strip().lower()
|
261 |
-
|
262 |
-
|
263 |
-
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
|
264 |
-
"""Preprocess to a float."""
|
265 |
-
if value is None:
|
266 |
-
return None
|
267 |
-
try:
|
268 |
-
return float(value)
|
269 |
-
except ValueError:
|
270 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/sound_dataset.py
DELETED
@@ -1,330 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Dataset of audio with a simple description.
|
7 |
-
"""
|
8 |
-
|
9 |
-
from dataclasses import dataclass, fields, replace
|
10 |
-
import json
|
11 |
-
from pathlib import Path
|
12 |
-
import random
|
13 |
-
import typing as tp
|
14 |
-
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
|
18 |
-
from .info_audio_dataset import (
|
19 |
-
InfoAudioDataset,
|
20 |
-
get_keyword_or_keyword_list
|
21 |
-
)
|
22 |
-
from ..modules.conditioners import (
|
23 |
-
ConditioningAttributes,
|
24 |
-
SegmentWithAttributes,
|
25 |
-
WavCondition,
|
26 |
-
)
|
27 |
-
|
28 |
-
|
29 |
-
EPS = torch.finfo(torch.float32).eps
|
30 |
-
TARGET_LEVEL_LOWER = -35
|
31 |
-
TARGET_LEVEL_UPPER = -15
|
32 |
-
|
33 |
-
|
34 |
-
@dataclass
|
35 |
-
class SoundInfo(SegmentWithAttributes):
|
36 |
-
"""Segment info augmented with Sound metadata.
|
37 |
-
"""
|
38 |
-
description: tp.Optional[str] = None
|
39 |
-
self_wav: tp.Optional[torch.Tensor] = None
|
40 |
-
|
41 |
-
@property
|
42 |
-
def has_sound_meta(self) -> bool:
|
43 |
-
return self.description is not None
|
44 |
-
|
45 |
-
def to_condition_attributes(self) -> ConditioningAttributes:
|
46 |
-
out = ConditioningAttributes()
|
47 |
-
|
48 |
-
for _field in fields(self):
|
49 |
-
key, value = _field.name, getattr(self, _field.name)
|
50 |
-
if key == 'self_wav':
|
51 |
-
out.wav[key] = value
|
52 |
-
else:
|
53 |
-
out.text[key] = value
|
54 |
-
return out
|
55 |
-
|
56 |
-
@staticmethod
|
57 |
-
def attribute_getter(attribute):
|
58 |
-
if attribute == 'description':
|
59 |
-
preprocess_func = get_keyword_or_keyword_list
|
60 |
-
else:
|
61 |
-
preprocess_func = None
|
62 |
-
return preprocess_func
|
63 |
-
|
64 |
-
@classmethod
|
65 |
-
def from_dict(cls, dictionary: dict, fields_required: bool = False):
|
66 |
-
_dictionary: tp.Dict[str, tp.Any] = {}
|
67 |
-
|
68 |
-
# allow a subset of attributes to not be loaded from the dictionary
|
69 |
-
# these attributes may be populated later
|
70 |
-
post_init_attributes = ['self_wav']
|
71 |
-
|
72 |
-
for _field in fields(cls):
|
73 |
-
if _field.name in post_init_attributes:
|
74 |
-
continue
|
75 |
-
elif _field.name not in dictionary:
|
76 |
-
if fields_required:
|
77 |
-
raise KeyError(f"Unexpected missing key: {_field.name}")
|
78 |
-
else:
|
79 |
-
preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
|
80 |
-
value = dictionary[_field.name]
|
81 |
-
if preprocess_func:
|
82 |
-
value = preprocess_func(value)
|
83 |
-
_dictionary[_field.name] = value
|
84 |
-
return cls(**_dictionary)
|
85 |
-
|
86 |
-
|
87 |
-
class SoundDataset(InfoAudioDataset):
|
88 |
-
"""Sound audio dataset: Audio dataset with environmental sound-specific metadata.
|
89 |
-
|
90 |
-
Args:
|
91 |
-
info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
|
92 |
-
external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
|
93 |
-
The metadata files contained in this folder are expected to match the stem of the audio file with
|
94 |
-
a json extension.
|
95 |
-
aug_p (float): Probability of performing audio mixing augmentation on the batch.
|
96 |
-
mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
|
97 |
-
mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
|
98 |
-
mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
|
99 |
-
mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
|
100 |
-
kwargs: Additional arguments for AudioDataset.
|
101 |
-
|
102 |
-
See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
|
103 |
-
"""
|
104 |
-
def __init__(
|
105 |
-
self,
|
106 |
-
*args,
|
107 |
-
info_fields_required: bool = True,
|
108 |
-
external_metadata_source: tp.Optional[str] = None,
|
109 |
-
aug_p: float = 0.,
|
110 |
-
mix_p: float = 0.,
|
111 |
-
mix_snr_low: int = -5,
|
112 |
-
mix_snr_high: int = 5,
|
113 |
-
mix_min_overlap: float = 0.5,
|
114 |
-
**kwargs
|
115 |
-
):
|
116 |
-
kwargs['return_info'] = True # We require the info for each song of the dataset.
|
117 |
-
super().__init__(*args, **kwargs)
|
118 |
-
self.info_fields_required = info_fields_required
|
119 |
-
self.external_metadata_source = external_metadata_source
|
120 |
-
self.aug_p = aug_p
|
121 |
-
self.mix_p = mix_p
|
122 |
-
if self.aug_p > 0:
|
123 |
-
assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
|
124 |
-
assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
|
125 |
-
self.mix_snr_low = mix_snr_low
|
126 |
-
self.mix_snr_high = mix_snr_high
|
127 |
-
self.mix_min_overlap = mix_min_overlap
|
128 |
-
|
129 |
-
def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
|
130 |
-
"""Get path of JSON with metadata (description, etc.).
|
131 |
-
If there exists a JSON with the same name as 'path.name', then it will be used.
|
132 |
-
Else, such JSON will be searched for in an external json source folder if it exists.
|
133 |
-
"""
|
134 |
-
info_path = Path(path).with_suffix('.json')
|
135 |
-
if Path(info_path).exists():
|
136 |
-
return info_path
|
137 |
-
elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
|
138 |
-
return Path(self.external_metadata_source) / info_path.name
|
139 |
-
else:
|
140 |
-
raise Exception(f"Unable to find a metadata JSON for path: {path}")
|
141 |
-
|
142 |
-
def __getitem__(self, index):
|
143 |
-
wav, info = super().__getitem__(index)
|
144 |
-
info_data = info.to_dict()
|
145 |
-
info_path = self._get_info_path(info.meta.path)
|
146 |
-
if Path(info_path).exists():
|
147 |
-
with open(info_path, 'r') as json_file:
|
148 |
-
sound_data = json.load(json_file)
|
149 |
-
sound_data.update(info_data)
|
150 |
-
sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
|
151 |
-
# if there are multiple descriptions, sample one randomly
|
152 |
-
if isinstance(sound_info.description, list):
|
153 |
-
sound_info.description = random.choice(sound_info.description)
|
154 |
-
else:
|
155 |
-
sound_info = SoundInfo.from_dict(info_data, fields_required=False)
|
156 |
-
|
157 |
-
sound_info.self_wav = WavCondition(
|
158 |
-
wav=wav[None], length=torch.tensor([info.n_frames]),
|
159 |
-
sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
|
160 |
-
|
161 |
-
return wav, sound_info
|
162 |
-
|
163 |
-
def collater(self, samples):
|
164 |
-
# when training, audio mixing is performed in the collate function
|
165 |
-
wav, sound_info = super().collater(samples) # SoundDataset always returns infos
|
166 |
-
if self.aug_p > 0:
|
167 |
-
wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
|
168 |
-
snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
|
169 |
-
min_overlap=self.mix_min_overlap)
|
170 |
-
return wav, sound_info
|
171 |
-
|
172 |
-
|
173 |
-
def rms_f(x: torch.Tensor) -> torch.Tensor:
|
174 |
-
return (x ** 2).mean(1).pow(0.5)
|
175 |
-
|
176 |
-
|
177 |
-
def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
|
178 |
-
"""Normalize the signal to the target level."""
|
179 |
-
rms = rms_f(audio)
|
180 |
-
scalar = 10 ** (target_level / 20) / (rms + EPS)
|
181 |
-
audio = audio * scalar.unsqueeze(1)
|
182 |
-
return audio
|
183 |
-
|
184 |
-
|
185 |
-
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
|
186 |
-
return (abs(audio) > clipping_threshold).any(1)
|
187 |
-
|
188 |
-
|
189 |
-
def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
|
190 |
-
start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
|
191 |
-
remainder = src.shape[1] - start
|
192 |
-
if dst.shape[1] > remainder:
|
193 |
-
src[:, start:] = src[:, start:] + dst[:, :remainder]
|
194 |
-
else:
|
195 |
-
src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
|
196 |
-
return src
|
197 |
-
|
198 |
-
|
199 |
-
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
|
200 |
-
target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
|
201 |
-
"""Function to mix clean speech and noise at various SNR levels.
|
202 |
-
|
203 |
-
Args:
|
204 |
-
clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
|
205 |
-
noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
|
206 |
-
snr (int): SNR level when mixing.
|
207 |
-
min_overlap (float): Minimum overlap between the two mixed sources.
|
208 |
-
target_level (int): Gain level in dB.
|
209 |
-
clipping_threshold (float): Threshold for clipping the audio.
|
210 |
-
Returns:
|
211 |
-
torch.Tensor: The mixed audio, of shape [B, T].
|
212 |
-
"""
|
213 |
-
if clean.shape[1] > noise.shape[1]:
|
214 |
-
noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
|
215 |
-
else:
|
216 |
-
noise = noise[:, :clean.shape[1]]
|
217 |
-
|
218 |
-
# normalizing to -25 dB FS
|
219 |
-
clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
|
220 |
-
clean = normalize(clean, target_level)
|
221 |
-
rmsclean = rms_f(clean)
|
222 |
-
|
223 |
-
noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
|
224 |
-
noise = normalize(noise, target_level)
|
225 |
-
rmsnoise = rms_f(noise)
|
226 |
-
|
227 |
-
# set the noise level for a given SNR
|
228 |
-
noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
|
229 |
-
noisenewlevel = noise * noisescalar
|
230 |
-
|
231 |
-
# mix noise and clean speech
|
232 |
-
noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
|
233 |
-
|
234 |
-
# randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
|
235 |
-
# there is a chance of clipping that might happen with very less probability, which is not a major issue.
|
236 |
-
noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
|
237 |
-
rmsnoisy = rms_f(noisyspeech)
|
238 |
-
scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
|
239 |
-
noisyspeech = noisyspeech * scalarnoisy
|
240 |
-
clean = clean * scalarnoisy
|
241 |
-
noisenewlevel = noisenewlevel * scalarnoisy
|
242 |
-
|
243 |
-
# final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
|
244 |
-
clipped = is_clipped(noisyspeech)
|
245 |
-
if clipped.any():
|
246 |
-
noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
|
247 |
-
noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
|
248 |
-
|
249 |
-
return noisyspeech
|
250 |
-
|
251 |
-
|
252 |
-
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
|
253 |
-
if snr_low == snr_high:
|
254 |
-
snr = snr_low
|
255 |
-
else:
|
256 |
-
snr = np.random.randint(snr_low, snr_high)
|
257 |
-
mix = snr_mixer(src, dst, snr, min_overlap)
|
258 |
-
return mix
|
259 |
-
|
260 |
-
|
261 |
-
def mix_text(src_text: str, dst_text: str):
|
262 |
-
"""Mix text from different sources by concatenating them."""
|
263 |
-
if src_text == dst_text:
|
264 |
-
return src_text
|
265 |
-
return src_text + " " + dst_text
|
266 |
-
|
267 |
-
|
268 |
-
def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
|
269 |
-
snr_low: int, snr_high: int, min_overlap: float):
|
270 |
-
"""Mix samples within a batch, summing the waveforms and concatenating the text infos.
|
271 |
-
|
272 |
-
Args:
|
273 |
-
wavs (torch.Tensor): Audio tensors of shape [B, C, T].
|
274 |
-
infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
|
275 |
-
aug_p (float): Augmentation probability.
|
276 |
-
mix_p (float): Proportion of items in the batch to mix (and merge) together.
|
277 |
-
snr_low (int): Lowerbound for sampling SNR.
|
278 |
-
snr_high (int): Upperbound for sampling SNR.
|
279 |
-
min_overlap (float): Minimum overlap between mixed samples.
|
280 |
-
Returns:
|
281 |
-
tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
|
282 |
-
and mixed SoundInfo for the given batch.
|
283 |
-
"""
|
284 |
-
# no mixing to perform within the batch
|
285 |
-
if mix_p == 0:
|
286 |
-
return wavs, infos
|
287 |
-
|
288 |
-
if random.uniform(0, 1) < aug_p:
|
289 |
-
# perform all augmentations on waveforms as [B, T]
|
290 |
-
# randomly picking pairs of audio to mix
|
291 |
-
assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
|
292 |
-
wavs = wavs.mean(dim=1, keepdim=False)
|
293 |
-
B, T = wavs.shape
|
294 |
-
k = int(mix_p * B)
|
295 |
-
mixed_sources_idx = torch.randperm(B)[:k]
|
296 |
-
mixed_targets_idx = torch.randperm(B)[:k]
|
297 |
-
aug_wavs = snr_mix(
|
298 |
-
wavs[mixed_sources_idx],
|
299 |
-
wavs[mixed_targets_idx],
|
300 |
-
snr_low,
|
301 |
-
snr_high,
|
302 |
-
min_overlap,
|
303 |
-
)
|
304 |
-
# mixing textual descriptions in metadata
|
305 |
-
descriptions = [info.description for info in infos]
|
306 |
-
aug_infos = []
|
307 |
-
for i, j in zip(mixed_sources_idx, mixed_targets_idx):
|
308 |
-
text = mix_text(descriptions[i], descriptions[j])
|
309 |
-
m = replace(infos[i])
|
310 |
-
m.description = text
|
311 |
-
aug_infos.append(m)
|
312 |
-
|
313 |
-
# back to [B, C, T]
|
314 |
-
aug_wavs = aug_wavs.unsqueeze(1)
|
315 |
-
assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
|
316 |
-
assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
|
317 |
-
assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
|
318 |
-
|
319 |
-
return aug_wavs, aug_infos # [B, C, T]
|
320 |
-
else:
|
321 |
-
# randomly pick samples in the batch to match
|
322 |
-
# the batch size when performing audio mixing
|
323 |
-
B, C, T = wavs.shape
|
324 |
-
k = int(mix_p * B)
|
325 |
-
wav_idx = torch.randperm(B)[:k]
|
326 |
-
wavs = wavs[wav_idx]
|
327 |
-
infos = [infos[i] for i in wav_idx]
|
328 |
-
assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
|
329 |
-
|
330 |
-
return wavs, infos # [B, C, T]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/zip.py
CHANGED
@@ -3,8 +3,6 @@
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Utility for reading some info from inside a zip file.
|
7 |
-
"""
|
8 |
|
9 |
import typing
|
10 |
import zipfile
|
@@ -20,13 +18,13 @@ MODE = Literal['r', 'w', 'x', 'a']
|
|
20 |
|
21 |
@dataclass(order=True)
|
22 |
class PathInZip:
|
23 |
-
"""
|
24 |
|
25 |
Args:
|
26 |
-
path
|
27 |
Let's assume there is a zip file /some/location/foo.zip
|
28 |
and inside of it is a json file located at /data/file1.json,
|
29 |
-
Then we expect path = "/some/location/foo.zip:/data/file1.json"
|
30 |
"""
|
31 |
|
32 |
INFO_PATH_SEP = ':'
|
@@ -57,7 +55,7 @@ def set_zip_cache_size(max_size: int):
|
|
57 |
"""Sets the maximal LRU caching for zip file opening.
|
58 |
|
59 |
Args:
|
60 |
-
max_size
|
61 |
"""
|
62 |
global _cached_open_zip
|
63 |
_cached_open_zip = lru_cache(max_size)(_open_zip)
|
@@ -67,8 +65,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
|
|
67 |
"""Opens a file stored inside a zip and returns a file-like object.
|
68 |
|
69 |
Args:
|
70 |
-
path_in_zip
|
71 |
-
mode
|
72 |
Returns:
|
73 |
A file-like object for PathInZip.
|
74 |
"""
|
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
6 |
|
7 |
import typing
|
8 |
import zipfile
|
|
|
18 |
|
19 |
@dataclass(order=True)
|
20 |
class PathInZip:
|
21 |
+
"""Class for holding a path of file within a zip file.
|
22 |
|
23 |
Args:
|
24 |
+
path: The convention is <path_to_zip>:<relative_path_inside_zip>
|
25 |
Let's assume there is a zip file /some/location/foo.zip
|
26 |
and inside of it is a json file located at /data/file1.json,
|
27 |
+
Then we expect path = "/some/location/foo.zip:/data/file1.json"
|
28 |
"""
|
29 |
|
30 |
INFO_PATH_SEP = ':'
|
|
|
55 |
"""Sets the maximal LRU caching for zip file opening.
|
56 |
|
57 |
Args:
|
58 |
+
max_size: the maximal LRU cache.
|
59 |
"""
|
60 |
global _cached_open_zip
|
61 |
_cached_open_zip = lru_cache(max_size)(_open_zip)
|
|
|
65 |
"""Opens a file stored inside a zip and returns a file-like object.
|
66 |
|
67 |
Args:
|
68 |
+
path_in_zip: A PathInZip object representing the file to return a file-like object of.
|
69 |
+
mode: The mode in which to open the file with.
|
70 |
Returns:
|
71 |
A file-like object for PathInZip.
|
72 |
"""
|
audiocraft/environment.py
DELETED
@@ -1,176 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
|
9 |
-
"""
|
10 |
-
|
11 |
-
import logging
|
12 |
-
import os
|
13 |
-
from pathlib import Path
|
14 |
-
import re
|
15 |
-
import typing as tp
|
16 |
-
|
17 |
-
import omegaconf
|
18 |
-
|
19 |
-
from .utils.cluster import _guess_cluster_type
|
20 |
-
|
21 |
-
|
22 |
-
logger = logging.getLogger(__name__)
|
23 |
-
|
24 |
-
|
25 |
-
class AudioCraftEnvironment:
|
26 |
-
"""Environment configuration for teams and clusters.
|
27 |
-
|
28 |
-
AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
|
29 |
-
or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
|
30 |
-
provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
|
31 |
-
allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
|
32 |
-
map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
|
33 |
-
|
34 |
-
The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
|
35 |
-
Use the following environment variables to specify the cluster, team or configuration:
|
36 |
-
|
37 |
-
AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
|
38 |
-
cannot be inferred automatically.
|
39 |
-
AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
|
40 |
-
If not set, configuration is read from config/teams.yaml.
|
41 |
-
AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
|
42 |
-
Cluster configuration are shared across teams to match compute allocation,
|
43 |
-
specify your cluster configuration in the configuration file under a key mapping
|
44 |
-
your team name.
|
45 |
-
"""
|
46 |
-
_instance = None
|
47 |
-
DEFAULT_TEAM = "default"
|
48 |
-
|
49 |
-
def __init__(self) -> None:
|
50 |
-
"""Loads configuration."""
|
51 |
-
self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
|
52 |
-
cluster_type = _guess_cluster_type()
|
53 |
-
cluster = os.getenv(
|
54 |
-
"AUDIOCRAFT_CLUSTER", cluster_type.value
|
55 |
-
)
|
56 |
-
logger.info("Detecting cluster type %s", cluster_type)
|
57 |
-
|
58 |
-
self.cluster: str = cluster
|
59 |
-
|
60 |
-
config_path = os.getenv(
|
61 |
-
"AUDIOCRAFT_CONFIG",
|
62 |
-
Path(__file__)
|
63 |
-
.parent.parent.joinpath("config/teams", self.team)
|
64 |
-
.with_suffix(".yaml"),
|
65 |
-
)
|
66 |
-
self.config = omegaconf.OmegaConf.load(config_path)
|
67 |
-
self._dataset_mappers = []
|
68 |
-
cluster_config = self._get_cluster_config()
|
69 |
-
if "dataset_mappers" in cluster_config:
|
70 |
-
for pattern, repl in cluster_config["dataset_mappers"].items():
|
71 |
-
regex = re.compile(pattern)
|
72 |
-
self._dataset_mappers.append((regex, repl))
|
73 |
-
|
74 |
-
def _get_cluster_config(self) -> omegaconf.DictConfig:
|
75 |
-
assert isinstance(self.config, omegaconf.DictConfig)
|
76 |
-
return self.config[self.cluster]
|
77 |
-
|
78 |
-
@classmethod
|
79 |
-
def instance(cls):
|
80 |
-
if cls._instance is None:
|
81 |
-
cls._instance = cls()
|
82 |
-
return cls._instance
|
83 |
-
|
84 |
-
@classmethod
|
85 |
-
def reset(cls):
|
86 |
-
"""Clears the environment and forces a reload on next invocation."""
|
87 |
-
cls._instance = None
|
88 |
-
|
89 |
-
@classmethod
|
90 |
-
def get_team(cls) -> str:
|
91 |
-
"""Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
|
92 |
-
If not defined, defaults to "labs".
|
93 |
-
"""
|
94 |
-
return cls.instance().team
|
95 |
-
|
96 |
-
@classmethod
|
97 |
-
def get_cluster(cls) -> str:
|
98 |
-
"""Gets the detected cluster.
|
99 |
-
This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
|
100 |
-
"""
|
101 |
-
return cls.instance().cluster
|
102 |
-
|
103 |
-
@classmethod
|
104 |
-
def get_dora_dir(cls) -> Path:
|
105 |
-
"""Gets the path to the dora directory for the current team and cluster.
|
106 |
-
Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
|
107 |
-
"""
|
108 |
-
cluster_config = cls.instance()._get_cluster_config()
|
109 |
-
dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
|
110 |
-
logger.warning(f"Dora directory: {dora_dir}")
|
111 |
-
return Path(dora_dir)
|
112 |
-
|
113 |
-
@classmethod
|
114 |
-
def get_reference_dir(cls) -> Path:
|
115 |
-
"""Gets the path to the reference directory for the current team and cluster.
|
116 |
-
Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
|
117 |
-
"""
|
118 |
-
cluster_config = cls.instance()._get_cluster_config()
|
119 |
-
return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
|
120 |
-
|
121 |
-
@classmethod
|
122 |
-
def get_slurm_exclude(cls) -> tp.Optional[str]:
|
123 |
-
"""Get the list of nodes to exclude for that cluster."""
|
124 |
-
cluster_config = cls.instance()._get_cluster_config()
|
125 |
-
return cluster_config.get("slurm_exclude")
|
126 |
-
|
127 |
-
@classmethod
|
128 |
-
def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
|
129 |
-
"""Gets the requested partitions for the current team and cluster as a comma-separated string.
|
130 |
-
|
131 |
-
Args:
|
132 |
-
partition_types (list[str], optional): partition types to retrieve. Values must be
|
133 |
-
from ['global', 'team']. If not provided, the global partition is returned.
|
134 |
-
"""
|
135 |
-
if not partition_types:
|
136 |
-
partition_types = ["global"]
|
137 |
-
|
138 |
-
cluster_config = cls.instance()._get_cluster_config()
|
139 |
-
partitions = [
|
140 |
-
cluster_config["partitions"][partition_type]
|
141 |
-
for partition_type in partition_types
|
142 |
-
]
|
143 |
-
return ",".join(partitions)
|
144 |
-
|
145 |
-
@classmethod
|
146 |
-
def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
|
147 |
-
"""Converts reference placeholder in path with configured reference dir to resolve paths.
|
148 |
-
|
149 |
-
Args:
|
150 |
-
path (str or Path): Path to resolve.
|
151 |
-
Returns:
|
152 |
-
Path: Resolved path.
|
153 |
-
"""
|
154 |
-
path = str(path)
|
155 |
-
|
156 |
-
if path.startswith("//reference"):
|
157 |
-
reference_dir = cls.get_reference_dir()
|
158 |
-
logger.warn(f"Reference directory: {reference_dir}")
|
159 |
-
assert (
|
160 |
-
reference_dir.exists() and reference_dir.is_dir()
|
161 |
-
), f"Reference directory does not exist: {reference_dir}."
|
162 |
-
path = re.sub("^//reference", str(reference_dir), path)
|
163 |
-
|
164 |
-
return Path(path)
|
165 |
-
|
166 |
-
@classmethod
|
167 |
-
def apply_dataset_mappers(cls, path: str) -> str:
|
168 |
-
"""Applies dataset mapping regex rules as defined in the configuration.
|
169 |
-
If no rules are defined, the path is returned as-is.
|
170 |
-
"""
|
171 |
-
instance = cls.instance()
|
172 |
-
|
173 |
-
for pattern, repl in instance._dataset_mappers:
|
174 |
-
path = pattern.sub(repl, path)
|
175 |
-
|
176 |
-
return path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Dora Grids."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/_base_explorers.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from abc import ABC, abstractmethod
|
8 |
-
import time
|
9 |
-
import typing as tp
|
10 |
-
from dora import Explorer
|
11 |
-
import treetable as tt
|
12 |
-
|
13 |
-
|
14 |
-
def get_sheep_ping(sheep) -> tp.Optional[str]:
|
15 |
-
"""Return the amount of time since the Sheep made some update
|
16 |
-
to its log. Returns a str using the relevant time unit."""
|
17 |
-
ping = None
|
18 |
-
if sheep.log is not None and sheep.log.exists():
|
19 |
-
delta = time.time() - sheep.log.stat().st_mtime
|
20 |
-
if delta > 3600 * 24:
|
21 |
-
ping = f'{delta / (3600 * 24):.1f}d'
|
22 |
-
elif delta > 3600:
|
23 |
-
ping = f'{delta / (3600):.1f}h'
|
24 |
-
elif delta > 60:
|
25 |
-
ping = f'{delta / 60:.1f}m'
|
26 |
-
else:
|
27 |
-
ping = f'{delta:.1f}s'
|
28 |
-
return ping
|
29 |
-
|
30 |
-
|
31 |
-
class BaseExplorer(ABC, Explorer):
|
32 |
-
"""Base explorer for AudioCraft grids.
|
33 |
-
|
34 |
-
All task specific solvers are expected to implement the `get_grid_metrics`
|
35 |
-
method to specify logic about metrics to display for a given task.
|
36 |
-
|
37 |
-
If additional stages are used, the child explorer must define how to handle
|
38 |
-
these new stages in the `process_history` and `process_sheep` methods.
|
39 |
-
"""
|
40 |
-
def stages(self):
|
41 |
-
return ["train", "valid", "evaluate"]
|
42 |
-
|
43 |
-
def get_grid_meta(self):
|
44 |
-
"""Returns the list of Meta information to display for each XP/job.
|
45 |
-
"""
|
46 |
-
return [
|
47 |
-
tt.leaf("index", align=">"),
|
48 |
-
tt.leaf("name", wrap=140),
|
49 |
-
tt.leaf("state"),
|
50 |
-
tt.leaf("sig", align=">"),
|
51 |
-
tt.leaf("sid", align="<"),
|
52 |
-
]
|
53 |
-
|
54 |
-
@abstractmethod
|
55 |
-
def get_grid_metrics(self):
|
56 |
-
"""Return the metrics that should be displayed in the tracking table.
|
57 |
-
"""
|
58 |
-
...
|
59 |
-
|
60 |
-
def process_sheep(self, sheep, history):
|
61 |
-
train = {
|
62 |
-
"epoch": len(history),
|
63 |
-
}
|
64 |
-
parts = {"train": train}
|
65 |
-
for metrics in history:
|
66 |
-
for key, sub in metrics.items():
|
67 |
-
part = parts.get(key, {})
|
68 |
-
if 'duration' in sub:
|
69 |
-
# Convert to minutes for readability.
|
70 |
-
sub['duration'] = sub['duration'] / 60.
|
71 |
-
part.update(sub)
|
72 |
-
parts[key] = part
|
73 |
-
ping = get_sheep_ping(sheep)
|
74 |
-
if ping is not None:
|
75 |
-
for name in self.stages():
|
76 |
-
if name not in parts:
|
77 |
-
parts[name] = {}
|
78 |
-
# Add the ping to each part for convenience.
|
79 |
-
parts[name]['ping'] = ping
|
80 |
-
return parts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/audiogen/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""AudioGen grids."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/audiogen/audiogen_base_16khz.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from ..musicgen._explorers import LMExplorer
|
8 |
-
from ...environment import AudioCraftEnvironment
|
9 |
-
|
10 |
-
|
11 |
-
@LMExplorer
|
12 |
-
def explorer(launcher):
|
13 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
-
launcher.slurm_(gpus=64, partition=partitions)
|
15 |
-
launcher.bind_(solver='audiogen/audiogen_base_16khz')
|
16 |
-
# replace this by the desired environmental sound dataset
|
17 |
-
launcher.bind_(dset='internal/sounds_16khz')
|
18 |
-
|
19 |
-
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
-
medium = {'model/lm/model_scale': 'medium'}
|
21 |
-
|
22 |
-
launcher.bind_(fsdp)
|
23 |
-
launcher(medium)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
DELETED
@@ -1,68 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Evaluation with objective metrics for the pretrained AudioGen models.
|
9 |
-
This grid takes signature from the training grid and runs evaluation-only stage.
|
10 |
-
|
11 |
-
When running the grid for the first time, please use:
|
12 |
-
REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
|
13 |
-
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
|
14 |
-
|
15 |
-
Note that you need the proper metrics external libraries setup to use all
|
16 |
-
the objective metrics activated in this grid. Refer to the README for more information.
|
17 |
-
"""
|
18 |
-
|
19 |
-
import os
|
20 |
-
|
21 |
-
from ..musicgen._explorers import GenerationEvalExplorer
|
22 |
-
from ...environment import AudioCraftEnvironment
|
23 |
-
from ... import train
|
24 |
-
|
25 |
-
|
26 |
-
def eval(launcher, batch_size: int = 32):
|
27 |
-
opts = {
|
28 |
-
'dset': 'audio/audiocaps_16khz',
|
29 |
-
'solver/audiogen/evaluation': 'objective_eval',
|
30 |
-
'execute_only': 'evaluate',
|
31 |
-
'+dataset.evaluate.batch_size': batch_size,
|
32 |
-
'+metrics.fad.tf.batch_size': 32,
|
33 |
-
}
|
34 |
-
# binary for FAD computation: replace this path with your own path
|
35 |
-
metrics_opts = {
|
36 |
-
'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
|
37 |
-
}
|
38 |
-
opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
|
39 |
-
opt2 = {'transformer_lm.two_step_cfg': True}
|
40 |
-
|
41 |
-
sub = launcher.bind(opts)
|
42 |
-
sub.bind_(metrics_opts)
|
43 |
-
|
44 |
-
# base objective metrics
|
45 |
-
sub(opt1, opt2)
|
46 |
-
|
47 |
-
|
48 |
-
@GenerationEvalExplorer
|
49 |
-
def explorer(launcher):
|
50 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
51 |
-
launcher.slurm_(gpus=4, partition=partitions)
|
52 |
-
|
53 |
-
if 'REGEN' not in os.environ:
|
54 |
-
folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
|
55 |
-
with launcher.job_array():
|
56 |
-
for sig in folder.iterdir():
|
57 |
-
if not sig.is_symlink():
|
58 |
-
continue
|
59 |
-
xp = train.main.get_xp_from_sig(sig.name)
|
60 |
-
launcher(xp.argv)
|
61 |
-
return
|
62 |
-
|
63 |
-
audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
|
64 |
-
audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
|
65 |
-
|
66 |
-
audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
|
67 |
-
audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
|
68 |
-
eval(audiogen_base_medium, batch_size=128)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/compression/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""EnCodec grids."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/compression/_explorers.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import treetable as tt
|
8 |
-
|
9 |
-
from .._base_explorers import BaseExplorer
|
10 |
-
|
11 |
-
|
12 |
-
class CompressionExplorer(BaseExplorer):
|
13 |
-
eval_metrics = ["sisnr", "visqol"]
|
14 |
-
|
15 |
-
def stages(self):
|
16 |
-
return ["train", "valid", "evaluate"]
|
17 |
-
|
18 |
-
def get_grid_meta(self):
|
19 |
-
"""Returns the list of Meta information to display for each XP/job.
|
20 |
-
"""
|
21 |
-
return [
|
22 |
-
tt.leaf("index", align=">"),
|
23 |
-
tt.leaf("name", wrap=140),
|
24 |
-
tt.leaf("state"),
|
25 |
-
tt.leaf("sig", align=">"),
|
26 |
-
]
|
27 |
-
|
28 |
-
def get_grid_metrics(self):
|
29 |
-
"""Return the metrics that should be displayed in the tracking table.
|
30 |
-
"""
|
31 |
-
return [
|
32 |
-
tt.group(
|
33 |
-
"train",
|
34 |
-
[
|
35 |
-
tt.leaf("epoch"),
|
36 |
-
tt.leaf("bandwidth", ".2f"),
|
37 |
-
tt.leaf("adv", ".4f"),
|
38 |
-
tt.leaf("d_loss", ".4f"),
|
39 |
-
],
|
40 |
-
align=">",
|
41 |
-
),
|
42 |
-
tt.group(
|
43 |
-
"valid",
|
44 |
-
[
|
45 |
-
tt.leaf("bandwidth", ".2f"),
|
46 |
-
tt.leaf("adv", ".4f"),
|
47 |
-
tt.leaf("msspec", ".4f"),
|
48 |
-
tt.leaf("sisnr", ".2f"),
|
49 |
-
],
|
50 |
-
align=">",
|
51 |
-
),
|
52 |
-
tt.group(
|
53 |
-
"evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
|
54 |
-
),
|
55 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/compression/debug.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
-
Any new exp added there will be scheduled.
|
10 |
-
You can cancel and experiment by commenting its line.
|
11 |
-
|
12 |
-
This grid is a minimal example for debugging compression task
|
13 |
-
and how to override parameters directly in a grid.
|
14 |
-
Learn more about dora grids: https://github.com/facebookresearch/dora
|
15 |
-
"""
|
16 |
-
|
17 |
-
from ._explorers import CompressionExplorer
|
18 |
-
from ...environment import AudioCraftEnvironment
|
19 |
-
|
20 |
-
|
21 |
-
@CompressionExplorer
|
22 |
-
def explorer(launcher):
|
23 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
24 |
-
launcher.slurm_(gpus=2, partition=partitions)
|
25 |
-
launcher.bind_(solver='compression/debug')
|
26 |
-
|
27 |
-
with launcher.job_array():
|
28 |
-
# base debug task using config from solver=compression/debug
|
29 |
-
launcher()
|
30 |
-
# we can override parameters in the grid to launch additional xps
|
31 |
-
launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/compression/encodec_audiogen_16khz.py
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
-
Any new exp added there will be scheduled.
|
10 |
-
You can cancel and experiment by commenting its line.
|
11 |
-
|
12 |
-
This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
|
13 |
-
"""
|
14 |
-
|
15 |
-
from ._explorers import CompressionExplorer
|
16 |
-
from ...environment import AudioCraftEnvironment
|
17 |
-
|
18 |
-
|
19 |
-
@CompressionExplorer
|
20 |
-
def explorer(launcher):
|
21 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
-
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
-
# use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
|
24 |
-
# AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
|
25 |
-
launcher.bind_(solver='compression/encodec_audiogen_16khz')
|
26 |
-
# replace this by the desired sound dataset
|
27 |
-
launcher.bind_(dset='internal/sounds_16khz')
|
28 |
-
# launch xp
|
29 |
-
launcher()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/compression/encodec_base_24khz.py
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
-
Any new exp added there will be scheduled.
|
10 |
-
You can cancel and experiment by commenting its line.
|
11 |
-
|
12 |
-
This grid shows how to train a base causal EnCodec model at 24 kHz.
|
13 |
-
"""
|
14 |
-
|
15 |
-
from ._explorers import CompressionExplorer
|
16 |
-
from ...environment import AudioCraftEnvironment
|
17 |
-
|
18 |
-
|
19 |
-
@CompressionExplorer
|
20 |
-
def explorer(launcher):
|
21 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
-
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
-
# base causal EnCodec trained on monophonic audio sampled at 24 kHz
|
24 |
-
launcher.bind_(solver='compression/encodec_base_24khz')
|
25 |
-
# replace this by the desired dataset
|
26 |
-
launcher.bind_(dset='audio/example')
|
27 |
-
# launch xp
|
28 |
-
launcher()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/compression/encodec_musicgen_32khz.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Grid search file, simply list all the exp you want in `explorer`.
|
9 |
-
Any new exp added there will be scheduled.
|
10 |
-
You can cancel and experiment by commenting its line.
|
11 |
-
|
12 |
-
This grid shows how to train a MusicGen EnCodec model at 32 kHz.
|
13 |
-
"""
|
14 |
-
|
15 |
-
from ._explorers import CompressionExplorer
|
16 |
-
from ...environment import AudioCraftEnvironment
|
17 |
-
|
18 |
-
|
19 |
-
@CompressionExplorer
|
20 |
-
def explorer(launcher):
|
21 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
22 |
-
launcher.slurm_(gpus=8, partition=partitions)
|
23 |
-
# use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
|
24 |
-
# MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
|
25 |
-
launcher.bind_(solver='compression/encodec_musicgen_32khz')
|
26 |
-
# replace this by the desired music dataset
|
27 |
-
launcher.bind_(dset='internal/music_400k_32khz')
|
28 |
-
# launch xp
|
29 |
-
launcher()
|
30 |
-
launcher({
|
31 |
-
'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
|
32 |
-
'label': 'visqol',
|
33 |
-
'evaluate.metrics.visqol': True
|
34 |
-
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/diffusion/4_bands_base_32khz.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Training of the 4 diffusion models described in
|
9 |
-
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
|
10 |
-
(paper link).
|
11 |
-
"""
|
12 |
-
|
13 |
-
from ._explorers import DiffusionExplorer
|
14 |
-
|
15 |
-
|
16 |
-
@DiffusionExplorer
|
17 |
-
def explorer(launcher):
|
18 |
-
launcher.slurm_(gpus=4, partition='learnfair')
|
19 |
-
|
20 |
-
launcher.bind_({'solver': 'diffusion/default',
|
21 |
-
'dset': 'internal/music_10k_32khz'})
|
22 |
-
|
23 |
-
with launcher.job_array():
|
24 |
-
launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
|
25 |
-
launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
|
26 |
-
launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
|
27 |
-
launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/diffusion/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""Diffusion grids."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/diffusion/_explorers.py
DELETED
@@ -1,66 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import treetable as tt
|
8 |
-
|
9 |
-
from .._base_explorers import BaseExplorer
|
10 |
-
|
11 |
-
|
12 |
-
class DiffusionExplorer(BaseExplorer):
|
13 |
-
eval_metrics = ["sisnr", "visqol"]
|
14 |
-
|
15 |
-
def stages(self):
|
16 |
-
return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
|
17 |
-
|
18 |
-
def get_grid_meta(self):
|
19 |
-
"""Returns the list of Meta information to display for each XP/job.
|
20 |
-
"""
|
21 |
-
return [
|
22 |
-
tt.leaf("index", align=">"),
|
23 |
-
tt.leaf("name", wrap=140),
|
24 |
-
tt.leaf("state"),
|
25 |
-
tt.leaf("sig", align=">"),
|
26 |
-
]
|
27 |
-
|
28 |
-
def get_grid_metrics(self):
|
29 |
-
"""Return the metrics that should be displayed in the tracking table.
|
30 |
-
"""
|
31 |
-
return [
|
32 |
-
tt.group(
|
33 |
-
"train",
|
34 |
-
[
|
35 |
-
tt.leaf("epoch"),
|
36 |
-
tt.leaf("loss", ".3%"),
|
37 |
-
],
|
38 |
-
align=">",
|
39 |
-
),
|
40 |
-
tt.group(
|
41 |
-
"valid",
|
42 |
-
[
|
43 |
-
tt.leaf("loss", ".3%"),
|
44 |
-
# tt.leaf("loss_0", ".3%"),
|
45 |
-
],
|
46 |
-
align=">",
|
47 |
-
),
|
48 |
-
tt.group(
|
49 |
-
"valid_ema",
|
50 |
-
[
|
51 |
-
tt.leaf("loss", ".3%"),
|
52 |
-
# tt.leaf("loss_0", ".3%"),
|
53 |
-
],
|
54 |
-
align=">",
|
55 |
-
),
|
56 |
-
tt.group(
|
57 |
-
"evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
|
58 |
-
tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
|
59 |
-
tt.leaf("rvm_3", ".4f"), ], align=">"
|
60 |
-
),
|
61 |
-
tt.group(
|
62 |
-
"evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
|
63 |
-
tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
|
64 |
-
tt.leaf("rvm_3", ".4f")], align=">"
|
65 |
-
),
|
66 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/musicgen/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
"""MusicGen grids."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/musicgen/_explorers.py
DELETED
@@ -1,93 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import typing as tp
|
8 |
-
|
9 |
-
import treetable as tt
|
10 |
-
|
11 |
-
from .._base_explorers import BaseExplorer
|
12 |
-
|
13 |
-
|
14 |
-
class LMExplorer(BaseExplorer):
|
15 |
-
eval_metrics: tp.List[str] = []
|
16 |
-
|
17 |
-
def stages(self) -> tp.List[str]:
|
18 |
-
return ['train', 'valid']
|
19 |
-
|
20 |
-
def get_grid_metrics(self):
|
21 |
-
"""Return the metrics that should be displayed in the tracking table."""
|
22 |
-
return [
|
23 |
-
tt.group(
|
24 |
-
'train',
|
25 |
-
[
|
26 |
-
tt.leaf('epoch'),
|
27 |
-
tt.leaf('duration', '.1f'), # duration in minutes
|
28 |
-
tt.leaf('ping'),
|
29 |
-
tt.leaf('ce', '.4f'), # cross entropy
|
30 |
-
tt.leaf("ppl", '.3f'), # perplexity
|
31 |
-
],
|
32 |
-
align='>',
|
33 |
-
),
|
34 |
-
tt.group(
|
35 |
-
'valid',
|
36 |
-
[
|
37 |
-
tt.leaf('ce', '.4f'),
|
38 |
-
tt.leaf('ppl', '.3f'),
|
39 |
-
tt.leaf('best_ppl', '.3f'),
|
40 |
-
],
|
41 |
-
align='>',
|
42 |
-
),
|
43 |
-
]
|
44 |
-
|
45 |
-
def process_sheep(self, sheep, history):
|
46 |
-
parts = super().process_sheep(sheep, history)
|
47 |
-
|
48 |
-
track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
|
49 |
-
best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
|
50 |
-
|
51 |
-
def comparator(mode, a, b):
|
52 |
-
return a < b if mode == 'lower' else a > b
|
53 |
-
|
54 |
-
for metrics in history:
|
55 |
-
for key, sub in metrics.items():
|
56 |
-
for metric in track_by:
|
57 |
-
# for the validation set, keep track of best metrics (ppl in this example)
|
58 |
-
# this is so we can conveniently compare metrics between runs in the grid
|
59 |
-
if key == 'valid' and metric in sub and comparator(
|
60 |
-
track_by[metric], sub[metric], best_metrics[metric]
|
61 |
-
):
|
62 |
-
best_metrics[metric] = sub[metric]
|
63 |
-
|
64 |
-
if 'valid' in parts:
|
65 |
-
parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
|
66 |
-
return parts
|
67 |
-
|
68 |
-
|
69 |
-
class GenerationEvalExplorer(BaseExplorer):
|
70 |
-
eval_metrics: tp.List[str] = []
|
71 |
-
|
72 |
-
def stages(self) -> tp.List[str]:
|
73 |
-
return ['evaluate']
|
74 |
-
|
75 |
-
def get_grid_metrics(self):
|
76 |
-
"""Return the metrics that should be displayed in the tracking table."""
|
77 |
-
return [
|
78 |
-
tt.group(
|
79 |
-
'evaluate',
|
80 |
-
[
|
81 |
-
tt.leaf('epoch', '.3f'),
|
82 |
-
tt.leaf('duration', '.1f'),
|
83 |
-
tt.leaf('ping'),
|
84 |
-
tt.leaf('ce', '.4f'),
|
85 |
-
tt.leaf('ppl', '.3f'),
|
86 |
-
tt.leaf('fad', '.3f'),
|
87 |
-
tt.leaf('kld', '.3f'),
|
88 |
-
tt.leaf('text_consistency', '.3f'),
|
89 |
-
tt.leaf('chroma_cosine', '.3f'),
|
90 |
-
],
|
91 |
-
align='>',
|
92 |
-
),
|
93 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/musicgen/musicgen_base_32khz.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from ._explorers import LMExplorer
|
8 |
-
from ...environment import AudioCraftEnvironment
|
9 |
-
|
10 |
-
|
11 |
-
@LMExplorer
|
12 |
-
def explorer(launcher):
|
13 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
-
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
-
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
-
# replace this by the desired music dataset
|
17 |
-
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
-
|
19 |
-
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
-
medium = {'model/lm/model_scale': 'medium'}
|
21 |
-
large = {'model/lm/model_scale': 'large'}
|
22 |
-
|
23 |
-
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
-
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
-
|
26 |
-
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
-
|
28 |
-
launcher.bind_(fsdp)
|
29 |
-
|
30 |
-
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
31 |
-
with launcher.job_array():
|
32 |
-
sub = launcher.bind()
|
33 |
-
sub()
|
34 |
-
|
35 |
-
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
36 |
-
with launcher.job_array():
|
37 |
-
sub = launcher.bind()
|
38 |
-
sub(medium, adam)
|
39 |
-
|
40 |
-
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
41 |
-
with launcher.job_array():
|
42 |
-
sub = launcher.bind()
|
43 |
-
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
from ._explorers import LMExplorer
|
8 |
-
from ...environment import AudioCraftEnvironment
|
9 |
-
|
10 |
-
|
11 |
-
@LMExplorer
|
12 |
-
def explorer(launcher):
|
13 |
-
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
|
14 |
-
launcher.slurm_(gpus=32, partition=partitions)
|
15 |
-
launcher.bind_(solver='musicgen/musicgen_base_32khz')
|
16 |
-
# replace this by the desired music dataset
|
17 |
-
launcher.bind_(dset='internal/music_400k_32khz')
|
18 |
-
|
19 |
-
fsdp = {'autocast': False, 'fsdp.use': True}
|
20 |
-
medium = {'model/lm/model_scale': 'medium'}
|
21 |
-
large = {'model/lm/model_scale': 'large'}
|
22 |
-
|
23 |
-
cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
|
24 |
-
wd_low = {'conditioners.description.t5.word_dropout': 0.2}
|
25 |
-
|
26 |
-
adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
|
27 |
-
|
28 |
-
# BEGINNING OF CACHE WRITING JOBS.
|
29 |
-
cache_write = {
|
30 |
-
'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
|
31 |
-
'cache.write': True,
|
32 |
-
'generate.every': 500,
|
33 |
-
'evaluate.every': 500,
|
34 |
-
'logging.log_updates': 50,
|
35 |
-
}
|
36 |
-
|
37 |
-
cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
|
38 |
-
cache_sub.bind_({'deadlock.use': True})
|
39 |
-
cache_sub.slurm_(gpus=8)
|
40 |
-
with launcher.job_array():
|
41 |
-
num_shards = 10 # total number of jobs running in parallel.
|
42 |
-
for shard in range(0, num_shards):
|
43 |
-
launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
|
44 |
-
|
45 |
-
# REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
|
46 |
-
# OR SUFFICIENTLY AHEAD.
|
47 |
-
return
|
48 |
-
|
49 |
-
cache = {
|
50 |
-
'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
|
51 |
-
}
|
52 |
-
launcher.bind_(fsdp, cache)
|
53 |
-
|
54 |
-
launcher.slurm_(gpus=32).bind_(label='32gpus')
|
55 |
-
with launcher.job_array():
|
56 |
-
sub = launcher.bind()
|
57 |
-
sub()
|
58 |
-
|
59 |
-
launcher.slurm_(gpus=64).bind_(label='64gpus')
|
60 |
-
with launcher.job_array():
|
61 |
-
sub = launcher.bind()
|
62 |
-
sub(medium, adam)
|
63 |
-
|
64 |
-
launcher.slurm_(gpus=96).bind_(label='96gpus')
|
65 |
-
with launcher.job_array():
|
66 |
-
sub = launcher.bind()
|
67 |
-
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|