sonalkum committed on
Commit
dd843b9
1 Parent(s): 6030e58

synthio-stable

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. stable/LICENSE +0 -21
  2. stable/LICENSES/LICENSE_ADP.txt +0 -21
  3. stable/LICENSES/LICENSE_AURALOSS.txt +0 -201
  4. stable/LICENSES/LICENSE_DESCRIPT.txt +0 -21
  5. stable/LICENSES/LICENSE_META.txt +0 -21
  6. stable/LICENSES/LICENSE_NVIDIA.txt +0 -21
  7. stable/LICENSES/LICENSE_XTRANSFORMERS.txt +0 -21
  8. stable/README.md +0 -157
  9. stable/build/lib/stable_audio_tools/__init__.py +0 -2
  10. stable/build/lib/stable_audio_tools/data/__init__.py +0 -0
  11. stable/build/lib/stable_audio_tools/data/dataset.py +0 -654
  12. stable/build/lib/stable_audio_tools/data/utils.py +0 -96
  13. stable/build/lib/stable_audio_tools/inference/__init__.py +0 -0
  14. stable/build/lib/stable_audio_tools/inference/generation.py +0 -274
  15. stable/build/lib/stable_audio_tools/inference/sampling.py +0 -232
  16. stable/build/lib/stable_audio_tools/inference/utils.py +0 -35
  17. stable/build/lib/stable_audio_tools/interface/__init__.py +0 -0
  18. stable/build/lib/stable_audio_tools/interface/gradio.py +0 -700
  19. stable/build/lib/stable_audio_tools/models/__init__.py +0 -1
  20. stable/build/lib/stable_audio_tools/models/adp.py +0 -1588
  21. stable/build/lib/stable_audio_tools/models/autoencoders.py +0 -794
  22. stable/build/lib/stable_audio_tools/models/blocks.py +0 -339
  23. stable/build/lib/stable_audio_tools/models/bottleneck.py +0 -326
  24. stable/build/lib/stable_audio_tools/models/codebook_patterns.py +0 -545
  25. stable/build/lib/stable_audio_tools/models/conditioners.py +0 -561
  26. stable/build/lib/stable_audio_tools/models/diffusion.py +0 -701
  27. stable/build/lib/stable_audio_tools/models/diffusion_prior.py +0 -79
  28. stable/build/lib/stable_audio_tools/models/discriminators.py +0 -546
  29. stable/build/lib/stable_audio_tools/models/dit.py +0 -379
  30. stable/build/lib/stable_audio_tools/models/factory.py +0 -153
  31. stable/build/lib/stable_audio_tools/models/lm.py +0 -541
  32. stable/build/lib/stable_audio_tools/models/lm_backbone.py +0 -159
  33. stable/build/lib/stable_audio_tools/models/local_attention.py +0 -278
  34. stable/build/lib/stable_audio_tools/models/pqmf.py +0 -393
  35. stable/build/lib/stable_audio_tools/models/pretrained.py +0 -25
  36. stable/build/lib/stable_audio_tools/models/pretransforms.py +0 -258
  37. stable/build/lib/stable_audio_tools/models/transformer.py +0 -805
  38. stable/build/lib/stable_audio_tools/models/utils.py +0 -89
  39. stable/build/lib/stable_audio_tools/models/wavelets.py +0 -82
  40. stable/build/lib/stable_audio_tools/training/__init__.py +0 -1
  41. stable/build/lib/stable_audio_tools/training/autoencoders.py +0 -477
  42. stable/build/lib/stable_audio_tools/training/diffusion.py +0 -1505
  43. stable/build/lib/stable_audio_tools/training/factory.py +0 -240
  44. stable/build/lib/stable_audio_tools/training/lm.py +0 -267
  45. stable/build/lib/stable_audio_tools/training/losses/__init__.py +0 -1
  46. stable/build/lib/stable_audio_tools/training/losses/auraloss.py +0 -607
  47. stable/build/lib/stable_audio_tools/training/losses/losses.py +0 -101
  48. stable/build/lib/stable_audio_tools/training/utils.py +0 -111
  49. stable/config_adapter.json +0 -124
  50. stable/convert_json.py +0 -44
stable/LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 Stability AI
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
stable/LICENSES/LICENSE_ADP.txt DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2022 archinet.ai
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
stable/LICENSES/LICENSE_AURALOSS.txt DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
stable/LICENSES/LICENSE_DESCRIPT.txt DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2023-present, Descript
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
stable/LICENSES/LICENSE_META.txt DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) Meta Platforms, Inc. and affiliates.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
stable/LICENSES/LICENSE_NVIDIA.txt DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2022 NVIDIA CORPORATION.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
stable/LICENSES/LICENSE_XTRANSFORMERS.txt DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2020 Phil Wang
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
stable/README.md DELETED
@@ -1,157 +0,0 @@
1
- # stable-audio-tools
2
- Training and inference code for audio generation models
3
-
4
- # Install
5
-
6
- The library can be installed from PyPI with:
7
- ```bash
8
- $ pip install stable-audio-tools
9
- ```
10
-
11
- To run the training scripts or inference code, you'll want to clone this repository, navigate to the root, and run:
12
- ```bash
13
- $ pip install .
14
- ```
15
-
16
- # Requirements
17
- Requires PyTorch 2.0 or later for Flash Attention support
18
-
19
- Development for the repo is done in Python 3.8.10
20
-
21
- # Interface
22
-
23
- A basic Gradio interface is provided to test out trained models.
24
-
25
- For example, to create an interface for the [`stable-audio-open-1.0`](https://huggingface.co/stabilityai/stable-audio-open-1.0) model, once you've accepted the terms for the model on Hugging Face, you can run:
26
- ```bash
27
- $ python3 ./run_gradio.py --pretrained-name stabilityai/stable-audio-open-1.0
28
- ```
29
-
30
- The `run_gradio.py` script accepts the following command line arguments:
31
-
32
- - `--pretrained-name`
33
- - Hugging Face repository name for a Stable Audio Tools model
34
- - Will prioritize `model.safetensors` over `model.ckpt` in the repo
35
- - Optional, used in place of `model-config` and `ckpt-path` when using pre-trained model checkpoints on Hugging Face
36
- - `--model-config`
37
- - Path to the model config file for a local model
38
- - `--ckpt-path`
39
- - Path to unwrapped model checkpoint file for a local model
40
- - `--pretransform-ckpt-path`
41
- Path to an unwrapped pretransform checkpoint; replaces the pretransform in the model, useful for testing out fine-tuned decoders
42
- - Optional
43
- - `--share`
44
- - If true, a publicly shareable link will be created for the Gradio demo
45
- - Optional
46
- - `--username` and `--password`
47
- - Used together to set a login for the Gradio demo
48
- - Optional
49
- - `--model-half`
50
- If true, the model weights are converted to half-precision
51
- - Optional
52
-
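For instance, to serve a locally trained model rather than a Hugging Face checkpoint, the flags above can be combined roughly as follows (the paths are placeholders, and an unwrapped checkpoint is assumed):

```bash
$ python3 ./run_gradio.py --model-config /path/to/model_config.json --ckpt-path /path/to/unwrapped_model.ckpt --share
```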
53
- # Training
54
-
55
- ## Prerequisites
56
- Before starting your training run, you'll need a model config file, as well as a dataset config file. For more information about those, refer to the Configurations section below
57
-
58
- The training code also requires a Weights & Biases account to log the training outputs and demos. Create an account and log in with:
59
- ```bash
60
- $ wandb login
61
- ```
62
-
63
- ## Start training
64
- To start a training run, run the `train.py` script in the repo root with:
65
- ```bash
66
- $ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_train
67
- ```
68
-
69
- The `--name` parameter will set the project name for your Weights and Biases run.
70
-
71
- ## Training wrappers and model unwrapping
72
- `stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training.
73
-
74
- When a model is being trained, it is wrapped in a "training wrapper", which is a `pl.LightningModule` that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states.
75
-
76
- The checkpoint files created during training include this training wrapper, which greatly increases the size of the checkpoint file.
77
-
78
- `unwrap_model.py` in the repo root will take in a wrapped model checkpoint and save a new checkpoint file including only the model itself.
79
-
80
- That can be run from the repo root with:
81
- ```bash
82
- $ python3 ./unwrap_model.py --model-config /path/to/model/config --ckpt-path /path/to/wrapped/ckpt --name model_unwrap
83
- ```
84
-
85
- Unwrapped model checkpoints are required for:
86
- - Inference scripts
87
- - Using a model as a pretransform for another model (e.g. using an autoencoder model for latent diffusion)
88
- - Fine-tuning a pre-trained model with a modified configuration (i.e. partial initialization)
89
-
90
- ## Fine-tuning
91
- Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.
92
-
93
- To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to `train.py` with the `--ckpt-path` flag.
94
-
95
- To start a fresh training run using a pre-trained unwrapped model, you can pass in the unwrapped checkpoint to `train.py` with the `--pretrained-ckpt-path` flag.
96
-
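A concrete fine-tuning invocation, again with placeholder paths, might therefore look like:

```bash
$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name my_finetune --pretrained-ckpt-path /path/to/unwrapped_model.ckpt
```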
97
- ## Additional training flags
98
-
99
- Additional optional flags for `train.py` include:
100
- - `--config-file`
101
- - The path to the defaults.ini file in the repo root, required if running `train.py` from a directory other than the repo root
102
- - `--pretransform-ckpt-path`
103
- - Used in various model types such as latent diffusion models to load a pre-trained autoencoder. Requires an unwrapped model checkpoint.
104
- - `--save-dir`
105
- - The directory in which to save the model checkpoints
106
- - `--checkpoint-every`
107
- - The number of steps between saved checkpoints.
108
- - *Default*: 10000
109
- - `--batch-size`
110
- - Number of samples per-GPU during training. Should be set as large as your GPU VRAM will allow.
111
- - *Default*: 8
112
- - `--num-gpus`
113
- - Number of GPUs per-node to use for training
114
- - *Default*: 1
115
- - `--num-nodes`
116
- - Number of GPU nodes being used for training
117
- - *Default*: 1
118
- - `--accum-batches`
119
- - Enables and sets the number of batches for gradient batch accumulation. Useful for increasing effective batch size when training on smaller GPUs.
120
- - `--strategy`
121
- - Multi-GPU strategy for distributed training. Setting to `deepspeed` will enable DeepSpeed ZeRO Stage 2.
122
- *Default*: `ddp` if `--num-gpus` > 1, else None
123
- - `--precision`
124
- - floating-point precision to use during training
125
- - *Default*: 16
126
- - `--num-workers`
127
- - Number of CPU workers used by the data loader
128
- - `--seed`
129
- - RNG seed for PyTorch, helps with deterministic training
130
-
131
- # Configurations
132
- Training and inference code for `stable-audio-tools` is based around JSON configuration files that define model hyperparameters, training settings, and information about your training dataset.
133
-
134
- ## Model config
135
- The model config file defines all of the information needed to load a model for training or inference. It also contains the training configuration needed to fine-tune a model or train from scratch.
136
-
137
- The following properties are defined in the top level of the model configuration:
138
-
139
- - `model_type`
140
- - The type of model being defined, currently limited to one of `"autoencoder", "diffusion_uncond", "diffusion_cond", "diffusion_cond_inpaint", "diffusion_autoencoder", "lm"`.
141
- - `sample_size`
142
- - The length of the audio provided to the model during training, in samples. For diffusion models, this is also the raw audio sample length used for inference.
143
- - `sample_rate`
144
- - The sample rate of the audio provided to the model during training, and generated during inference, in Hz.
145
- - `audio_channels`
146
- - The number of channels of audio provided to the model during training, and generated during inference. Defaults to 2. Set to 1 for mono.
147
- - `model`
148
- - The specific configuration for the model being defined, varies based on `model_type`
149
- - `training`
150
- - The training configuration for the model, varies based on `model_type`. Provides parameters for training as well as demos.
151
-
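Putting the properties above together, a minimal model config could be sketched as follows. This is illustrative only: the top-level keys match the list above, but the example values and the contents of the `model` and `training` blocks are placeholders that depend entirely on the chosen `model_type`.

```json
{
  "model_type": "diffusion_cond",
  "sample_size": 2097152,
  "sample_rate": 48000,
  "audio_channels": 2,
  "model": { "comment": "architecture settings specific to model_type" },
  "training": { "comment": "training and demo settings specific to model_type" }
}
```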
152
- ## Dataset config
153
- `stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
154
-
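As a rough sketch of the local-directory case, a dataset config along the following lines is what `create_dataloader_from_config` in `stable_audio_tools/data/dataset.py` (shown further down in this diff) parses; the `id`, the paths, and the optional `custom_metadata_module` entry are placeholders.

```json
{
  "dataset_type": "audio_dir",
  "datasets": [
    {
      "id": "my_audio",
      "path": "/path/to/audio/files",
      "custom_metadata_module": "/path/to/custom_metadata.py"
    }
  ],
  "random_crop": true
}
```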
155
- # Todo
156
- - [ ] Add troubleshooting section
157
- - [ ] Add contribution guidelines
 
stable/build/lib/stable_audio_tools/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .models.factory import create_model_from_config, create_model_from_config_path
2
- from .models.pretrained import get_pretrained_model
 
stable/build/lib/stable_audio_tools/data/__init__.py DELETED
File without changes
stable/build/lib/stable_audio_tools/data/dataset.py DELETED
@@ -1,654 +0,0 @@
1
- import importlib
2
- import numpy as np
3
- import io
4
- import os
5
- import posixpath
6
- import random
7
- import re
8
- import subprocess
9
- import time
10
- import torch
11
- import torchaudio
12
- import webdataset as wds
13
-
14
- from aeiou.core import is_silence
15
- from os import path
16
- from pedalboard.io import AudioFile
17
- from torchaudio import transforms as T
18
- from typing import Optional, Callable, List
19
-
20
- from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
21
-
22
- AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
23
-
24
- # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
25
-
26
- def fast_scandir(
27
- dir:str, # top-level directory at which to begin scanning
28
- ext:list, # list of allowed file extensions,
29
- #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
30
- ):
31
- "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
32
- subfolders, files = [], []
33
- ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
34
- try: # hope to avoid 'permission denied' by this try
35
- for f in os.scandir(dir):
36
- try: # 'hope to avoid too many levels of symbolic links' error
37
- if f.is_dir():
38
- subfolders.append(f.path)
39
- elif f.is_file():
40
- file_ext = os.path.splitext(f.name)[1].lower()
41
- is_hidden = os.path.basename(f.path).startswith(".")
42
-
43
- if file_ext in ext and not is_hidden:
44
- files.append(f.path)
45
- except:
46
- pass
47
- except:
48
- pass
49
-
50
- for dir in list(subfolders):
51
- sf, f = fast_scandir(dir, ext)
52
- subfolders.extend(sf)
53
- files.extend(f)
54
- return subfolders, files
55
-
56
- def keyword_scandir(
57
- dir: str, # top-level directory at which to begin scanning
58
- ext: list, # list of allowed file extensions
59
- keywords: list, # list of keywords to search for in the file name
60
- ):
61
- "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
62
- subfolders, files = [], []
63
- # make keywords case insensitive
64
- keywords = [keyword.lower() for keyword in keywords]
65
- # add starting period to extensions if needed
66
- ext = ['.'+x if x[0] != '.' else x for x in ext]
67
- banned_words = ["paxheader", "__macosx"]
68
- try: # hope to avoid 'permission denied' by this try
69
- for f in os.scandir(dir):
70
- try: # 'hope to avoid too many levels of symbolic links' error
71
- if f.is_dir():
72
- subfolders.append(f.path)
73
- elif f.is_file():
74
- is_hidden = f.name.split("/")[-1][0] == '.'
75
- has_ext = os.path.splitext(f.name)[1].lower() in ext
76
- name_lower = f.name.lower()
77
- has_keyword = any(
78
- [keyword in name_lower for keyword in keywords])
79
- has_banned = any(
80
- [banned_word in name_lower for banned_word in banned_words])
81
- if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
82
- files.append(f.path)
83
- except:
84
- pass
85
- except:
86
- pass
87
-
88
- for dir in list(subfolders):
89
- sf, f = keyword_scandir(dir, ext, keywords)
90
- subfolders.extend(sf)
91
- files.extend(f)
92
- return subfolders, files
93
-
94
- def get_audio_filenames(
95
- paths: list, # directories in which to search
96
- keywords=None,
97
- exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
98
- ):
99
- "recursively get a list of audio filenames"
100
- filenames = []
101
- if type(paths) is str:
102
- paths = [paths]
103
- for path in paths: # get a list of relevant filenames
104
- if keywords is not None:
105
- subfolders, files = keyword_scandir(path, exts, keywords)
106
- else:
107
- subfolders, files = fast_scandir(path, exts)
108
- filenames.extend(files)
109
- return filenames
110
-
111
- class LocalDatasetConfig:
112
- def __init__(
113
- self,
114
- id: str,
115
- path: str,
116
- custom_metadata_fn: Optional[Callable[[str], str]] = None
117
- ):
118
- self.id = id
119
- self.path = path
120
- self.custom_metadata_fn = custom_metadata_fn
121
-
122
- class SampleDataset(torch.utils.data.Dataset):
123
- def __init__(
124
- self,
125
- configs,
126
- sample_size=65536,
127
- sample_rate=48000,
128
- keywords=None,
129
- random_crop=True,
130
- force_channels="stereo"
131
- ):
132
- super().__init__()
133
- self.filenames = []
134
-
135
- self.augs = torch.nn.Sequential(
136
- PhaseFlipper(),
137
- )
138
-
139
- self.root_paths = []
140
-
141
- self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
142
-
143
- self.force_channels = force_channels
144
-
145
- self.encoding = torch.nn.Sequential(
146
- Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
147
- Mono() if self.force_channels == "mono" else torch.nn.Identity(),
148
- )
149
-
150
- self.sr = sample_rate
151
-
152
- self.custom_metadata_fns = {}
153
-
154
- for config in configs:
155
- self.root_paths.append(config.path)
156
- self.filenames.extend(get_audio_filenames(config.path, keywords))
157
- if config.custom_metadata_fn is not None:
158
- self.custom_metadata_fns[config.path] = config.custom_metadata_fn
159
-
160
- print(f'Found {len(self.filenames)} files')
161
-
162
- def load_file(self, filename):
163
- ext = filename.split(".")[-1]
164
-
165
- if ext == "mp3":
166
- with AudioFile(filename) as f:
167
- audio = f.read(f.frames)
168
- audio = torch.from_numpy(audio)
169
- in_sr = f.samplerate
170
- else:
171
- audio, in_sr = torchaudio.load(filename, format=ext)
172
-
173
- if in_sr != self.sr:
174
- resample_tf = T.Resample(in_sr, self.sr)
175
- audio = resample_tf(audio)
176
-
177
- return audio
178
-
179
- def __len__(self):
180
- return len(self.filenames)
181
-
182
- def __getitem__(self, idx):
183
- audio_filename = self.filenames[idx]
184
- try:
185
- start_time = time.time()
186
- audio = self.load_file(audio_filename)
187
-
188
- audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
189
-
190
- # Run augmentations on this sample (including random crop)
191
- if self.augs is not None:
192
- audio = self.augs(audio)
193
-
194
- audio = audio.clamp(-1, 1)
195
-
196
- # Encode the file to assist in prediction
197
- if self.encoding is not None:
198
- audio = self.encoding(audio)
199
-
200
- info = {}
201
-
202
- info["path"] = audio_filename
203
-
204
- for root_path in self.root_paths:
205
- if root_path in audio_filename:
206
- info["relpath"] = path.relpath(audio_filename, root_path)
207
-
208
- info["timestamps"] = (t_start, t_end)
209
- info["seconds_start"] = seconds_start
210
- info["seconds_total"] = seconds_total
211
- info["padding_mask"] = padding_mask
212
-
213
- end_time = time.time()
214
-
215
- info["load_time"] = end_time - start_time
216
-
217
- for custom_md_path in self.custom_metadata_fns.keys():
218
- if custom_md_path in audio_filename:
219
- custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
220
- custom_metadata = custom_metadata_fn(info, audio)
221
- info.update(custom_metadata)
222
-
223
- if "__reject__" in info and info["__reject__"]:
224
- return self[random.randrange(len(self))]
225
-
226
- return (audio, info)
227
- except Exception as e:
228
- print(f'Couldn\'t load file {audio_filename}: {e}')
229
- return self[random.randrange(len(self))]
230
-
231
- def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
232
- """Return function over iterator that groups key, value pairs into samples.
233
- :param keys: function that splits the key into key and extension (base_plus_ext)
234
- :param lcase: convert suffixes to lower case (Default value = True)
235
- """
236
- current_sample = None
237
- for filesample in data:
238
- assert isinstance(filesample, dict)
239
- fname, value = filesample["fname"], filesample["data"]
240
- prefix, suffix = keys(fname)
241
- if wds.tariterators.trace:
242
- print(
243
- prefix,
244
- suffix,
245
- current_sample.keys() if isinstance(current_sample, dict) else None,
246
- )
247
- if prefix is None:
248
- continue
249
- if lcase:
250
- suffix = suffix.lower()
251
- if current_sample is None or prefix != current_sample["__key__"]:
252
- if wds.tariterators.valid_sample(current_sample):
253
- yield current_sample
254
- current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
255
- if suffix in current_sample:
256
- print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
257
- if suffixes is None or suffix in suffixes:
258
- current_sample[suffix] = value
259
- if wds.tariterators.valid_sample(current_sample):
260
- yield current_sample
261
-
262
- wds.tariterators.group_by_keys = group_by_keys
263
-
264
- # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
265
-
266
- def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
267
- """
268
- Returns a list of full S3 paths to files in a given S3 bucket and directory path.
269
- """
270
- # Ensure dataset_path ends with a trailing slash
271
- if dataset_path != '' and not dataset_path.endswith('/'):
272
- dataset_path += '/'
273
- # Use posixpath to construct the S3 URL path
274
- bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
275
- # Construct the `aws s3 ls` command
276
- cmd = ['aws', 's3', 'ls', bucket_path]
277
-
278
- if profile is not None:
279
- cmd.extend(['--profile', profile])
280
-
281
- if recursive:
282
- # Add the --recursive flag if requested
283
- cmd.append('--recursive')
284
-
285
- # Run the `aws s3 ls` command and capture the output
286
- run_ls = subprocess.run(cmd, capture_output=True, check=True)
287
- # Split the output into lines and strip whitespace from each line
288
- contents = run_ls.stdout.decode('utf-8').split('\n')
289
- contents = [x.strip() for x in contents if x]
290
- # Remove the timestamp from lines that begin with a timestamp
291
- contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
292
- if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
293
- # Construct a full S3 path for each file in the contents list
294
- contents = [posixpath.join(s3_url_prefix or '', x)
295
- for x in contents if not x.endswith('/')]
296
- # Apply the filter, if specified
297
- if filter:
298
- contents = [x for x in contents if filter in x]
299
- # Remove redundant directory names in the S3 URL
300
- if recursive:
301
- # Get the main directory name from the S3 URL
302
- main_dir = "/".join(bucket_path.split('/')[3:])
303
- # Remove the redundant directory names from each file path
304
- contents = [x.replace(f'{main_dir}', '').replace(
305
- '//', '/') for x in contents]
306
- # Print debugging information, if requested
307
- if debug:
308
- print("contents = \n", contents)
309
- # Return the list of S3 paths to files
310
- return contents
311
-
312
-
313
- def get_all_s3_urls(
314
- names=[], # list of all valid [LAION AudioDataset] dataset names
315
- # list of subsets you want from those datasets, e.g. ['train','valid']
316
- subsets=[''],
317
- s3_url_prefix=None, # prefix for those dataset names
318
- recursive=True, # recursively list all tar files in all subdirs
319
- filter_str='tar', # only grab files with this substring
320
- # print debugging info -- note: info displayed likely to change at dev's whims
321
- debug=False,
322
- profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
323
- ):
324
- "get urls of shards (tar files) for multiple datasets in one s3 bucket"
325
- urls = []
326
- for name in names:
327
- # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
328
- if s3_url_prefix is None:
329
- contents_str = name
330
- else:
331
- # Construct the S3 path using the s3_url_prefix and the current name value
332
- contents_str = posixpath.join(s3_url_prefix, name)
333
- if debug:
334
- print(f"get_all_s3_urls: {contents_str}:")
335
- for subset in subsets:
336
- subset_str = posixpath.join(contents_str, subset)
337
- if debug:
338
- print(f"subset_str = {subset_str}")
339
- # Get the list of tar files in the current subset directory
340
- profile = profiles.get(name, None)
341
- tar_list = get_s3_contents(
342
- subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
343
- for tar in tar_list:
344
- # Escape spaces and parentheses in the tar filename for use in the shell command
345
- tar = tar.replace(" ", "\ ").replace(
346
- "(", "\(").replace(")", "\)")
347
- # Construct the S3 path to the current tar file
348
- s3_path = posixpath.join(name, subset, tar) + " -"
349
- # Construct the AWS CLI command to download the current tar file
350
- if s3_url_prefix is None:
351
- request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
352
- else:
353
- request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
354
- if profiles.get(name):
355
- request_str += f" --profile {profiles.get(name)}"
356
- if debug:
357
- print("request_str = ", request_str)
358
- # Add the constructed URL to the list of URLs
359
- urls.append(request_str)
360
- return urls
361
-
362
-
363
- def log_and_continue(exn):
364
- """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
365
- print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
366
- return True
367
-
368
-
369
- def is_valid_sample(sample):
370
- has_json = "json" in sample
371
- has_audio = "audio" in sample
372
- is_silent = is_silence(sample["audio"])
373
- is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
374
-
375
- return has_json and has_audio and not is_silent and not is_rejected
376
-
377
- class S3DatasetConfig:
378
- def __init__(
379
- self,
380
- id: str,
381
- s3_path: str,
382
- custom_metadata_fn: Optional[Callable[[str], str]] = None,
383
- profile: Optional[str] = None,
384
- ):
385
- self.id = id
386
- self.path = s3_path
387
- self.custom_metadata_fn = custom_metadata_fn
388
- self.profile = profile
389
- self.urls = []
390
-
391
- def load_data_urls(self):
392
- self.urls = get_all_s3_urls(
393
- names=[self.path],
394
- s3_url_prefix=None,
395
- recursive=True,
396
- profiles={self.path: self.profile} if self.profile else {},
397
- )
398
-
399
- return self.urls
400
-
401
- class LocalWebDatasetConfig:
402
- def __init__(
403
- self,
404
- id: str,
405
- path: str,
406
- custom_metadata_fn: Optional[Callable[[str], str]] = None,
407
- profile: Optional[str] = None,
408
- ):
409
- self.id = id
410
- self.path = path
411
- self.custom_metadata_fn = custom_metadata_fn
412
- self.urls = []
413
-
414
- def load_data_urls(self):
415
-
416
- self.urls = fast_scandir(self.path, ["tar"])[1]
417
-
418
- return self.urls
419
-
420
- def audio_decoder(key, value):
421
- # Get file extension from key
422
- ext = key.split(".")[-1]
423
-
424
- if ext in AUDIO_KEYS:
425
- return torchaudio.load(io.BytesIO(value))
426
- else:
427
- return None
428
-
429
- def collation_fn(samples):
430
- batched = list(zip(*samples))
431
- result = []
432
- for b in batched:
433
- if isinstance(b[0], (int, float)):
434
- b = np.array(b)
435
- elif isinstance(b[0], torch.Tensor):
436
- b = torch.stack(b)
437
- elif isinstance(b[0], np.ndarray):
438
- b = np.array(b)
439
- else:
440
- b = b
441
- result.append(b)
442
- return result
443
-
444
- class WebDatasetDataLoader():
445
- def __init__(
446
- self,
447
- datasets: List[S3DatasetConfig],
448
- batch_size,
449
- sample_size,
450
- sample_rate=48000,
451
- num_workers=8,
452
- epoch_steps=1000,
453
- random_crop=True,
454
- force_channels="stereo",
455
- augment_phase=True,
456
- **data_loader_kwargs
457
- ):
458
-
459
- self.datasets = datasets
460
-
461
- self.sample_size = sample_size
462
- self.sample_rate = sample_rate
463
- self.random_crop = random_crop
464
- self.force_channels = force_channels
465
- self.augment_phase = augment_phase
466
-
467
- urls = [dataset.load_data_urls() for dataset in datasets]
468
-
469
- # Flatten the list of lists of URLs
470
- urls = [url for dataset_urls in urls for url in dataset_urls]
471
-
472
- # Shuffle the urls
473
- random.shuffle(urls)
474
-
475
- self.dataset = wds.DataPipeline(
476
- wds.ResampledShards(urls),
477
- wds.tarfile_to_samples(handler=log_and_continue),
478
- wds.decode(audio_decoder, handler=log_and_continue),
479
- wds.map(self.wds_preprocess, handler=log_and_continue),
480
- wds.select(is_valid_sample),
481
- wds.to_tuple("audio", "json", handler=log_and_continue),
482
- #wds.shuffle(bufsize=1000, initial=5000),
483
- wds.batched(batch_size, partial=False, collation_fn=collation_fn),
484
- ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
485
-
486
- self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
487
-
488
- def wds_preprocess(self, sample):
489
-
490
- found_key, rewrite_key = '', ''
491
- for k, v in sample.items(): # print the all entries in dict
492
- for akey in AUDIO_KEYS:
493
- if k.endswith(akey):
494
- # to rename long/weird key with its simpler counterpart
495
- found_key, rewrite_key = k, akey
496
- break
497
- if '' != found_key:
498
- break
499
- if '' == found_key: # got no audio!
500
- return None # try returning None to tell WebDataset to skip this one
501
-
502
- audio, in_sr = sample[found_key]
503
- if in_sr != self.sample_rate:
504
- resample_tf = T.Resample(in_sr, self.sample_rate)
505
- audio = resample_tf(audio)
506
-
507
- if self.sample_size is not None:
508
- # Pad/crop and get the relative timestamp
509
- pad_crop = PadCrop_Normalized_T(
510
- self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
511
- audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
512
- audio)
513
- sample["json"]["seconds_start"] = seconds_start
514
- sample["json"]["seconds_total"] = seconds_total
515
- sample["json"]["padding_mask"] = padding_mask
516
- else:
517
- t_start, t_end = 0, 1
518
-
519
- # Check if audio is length zero, initialize to a single zero if so
520
- if audio.shape[-1] == 0:
521
- audio = torch.zeros(1, 1)
522
-
523
- # Make the audio stereo and augment by randomly inverting phase
524
- augs = torch.nn.Sequential(
525
- Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
526
- Mono() if self.force_channels == "mono" else torch.nn.Identity(),
527
- PhaseFlipper() if self.augment_phase else torch.nn.Identity()
528
- )
529
-
530
- audio = augs(audio)
531
-
532
- sample["json"]["timestamps"] = (t_start, t_end)
533
-
534
- if "text" in sample["json"]:
535
- sample["json"]["prompt"] = sample["json"]["text"]
536
-
537
- # Check for custom metadata functions
538
- for dataset in self.datasets:
539
- if dataset.custom_metadata_fn is None:
540
- continue
541
-
542
- if dataset.path in sample["__url__"]:
543
- custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
544
- sample["json"].update(custom_metadata)
545
-
546
- if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
547
- del sample[found_key]
548
-
549
- sample["audio"] = audio
550
-
551
- # Add audio to the metadata as well for conditioning
552
- sample["json"]["audio"] = audio
553
-
554
- return sample
555
-
556
- def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
557
-
558
- dataset_type = dataset_config.get("dataset_type", None)
559
-
560
- assert dataset_type is not None, "Dataset type must be specified in dataset config"
561
-
562
- if audio_channels == 1:
563
- force_channels = "mono"
564
- else:
565
- force_channels = "stereo"
566
-
567
- if dataset_type == "audio_dir":
568
-
569
- audio_dir_configs = dataset_config.get("datasets", None)
570
-
571
- assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
572
-
573
- configs = []
574
-
575
- for audio_dir_config in audio_dir_configs:
576
- audio_dir_path = audio_dir_config.get("path", None)
577
- assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
578
-
579
- custom_metadata_fn = None
580
- custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
581
-
582
- if custom_metadata_module_path is not None:
583
- spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
584
- metadata_module = importlib.util.module_from_spec(spec)
585
- spec.loader.exec_module(metadata_module)
586
-
587
- custom_metadata_fn = metadata_module.get_custom_metadata
588
-
589
- configs.append(
590
- LocalDatasetConfig(
591
- id=audio_dir_config["id"],
592
- path=audio_dir_path,
593
- custom_metadata_fn=custom_metadata_fn
594
- )
595
- )
596
-
597
- train_set = SampleDataset(
598
- configs,
599
- sample_rate=sample_rate,
600
- sample_size=sample_size,
601
- random_crop=dataset_config.get("random_crop", True),
602
- force_channels=force_channels
603
- )
604
-
605
- return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
606
- num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
607
-
608
- elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
609
- wds_configs = []
610
-
611
- for wds_config in dataset_config["datasets"]:
612
-
613
- custom_metadata_fn = None
614
- custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
615
-
616
- if custom_metadata_module_path is not None:
617
- spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
618
- metadata_module = importlib.util.module_from_spec(spec)
619
- spec.loader.exec_module(metadata_module)
620
-
621
- custom_metadata_fn = metadata_module.get_custom_metadata
622
-
623
- if "s3_path" in wds_config:
624
-
625
- wds_configs.append(
626
- S3DatasetConfig(
627
- id=wds_config["id"],
628
- s3_path=wds_config["s3_path"],
629
- custom_metadata_fn=custom_metadata_fn,
630
- profile=wds_config.get("profile", None),
631
- )
632
- )
633
-
634
- elif "path" in wds_config:
635
-
636
- wds_configs.append(
637
- LocalWebDatasetConfig(
638
- id=wds_config["id"],
639
- path=wds_config["path"],
640
- custom_metadata_fn=custom_metadata_fn
641
- )
642
- )
643
-
644
- return WebDatasetDataLoader(
645
- wds_configs,
646
- sample_rate=sample_rate,
647
- sample_size=sample_size,
648
- batch_size=batch_size,
649
- random_crop=dataset_config.get("random_crop", True),
650
- num_workers=num_workers,
651
- persistent_workers=True,
652
- force_channels=force_channels,
653
- epoch_steps=dataset_config.get("epoch_steps", 2000)
654
- ).data_loader
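For context on the `custom_metadata_module` hook used by both the local and WebDataset paths above: the module is loaded with `importlib` and must expose `get_custom_metadata(info, audio)`, which receives the sample's metadata dict and audio tensor and returns extra keys to merge in; returning `"__reject__": True` causes the sample to be skipped. A minimal sketch follows; the filename-derived `"prompt"` is placeholder logic, and whether `"prompt"` is the right key depends on your conditioning config.

```python
# custom_metadata.py -- hypothetical module referenced by "custom_metadata_module"
import os


def get_custom_metadata(info, audio):
    # Drop clips whose source audio is shorter than one second
    if info.get("seconds_total", 0) < 1:
        return {"__reject__": True}

    # Derive a text prompt from the file name (placeholder logic)
    relpath = info.get("relpath", info.get("path", ""))
    prompt = os.path.splitext(os.path.basename(relpath))[0].replace("_", " ")

    return {"prompt": prompt}
```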
 
stable/build/lib/stable_audio_tools/data/utils.py DELETED
@@ -1,96 +0,0 @@
1
- import math
2
- import random
3
- import torch
4
-
5
- from torch import nn
6
- from typing import Tuple
7
-
8
- class PadCrop(nn.Module):
9
- def __init__(self, n_samples, randomize=True):
10
- super().__init__()
11
- self.n_samples = n_samples
12
- self.randomize = randomize
13
-
14
- def __call__(self, signal):
15
- n, s = signal.shape
16
- start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
17
- end = start + self.n_samples
18
- output = signal.new_zeros([n, self.n_samples])
19
- output[:, :min(s, self.n_samples)] = signal[:, start:end]
20
- return output
21
-
22
- class PadCrop_Normalized_T(nn.Module):
23
-
24
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
25
-
26
- super().__init__()
27
-
28
- self.n_samples = n_samples
29
- self.sample_rate = sample_rate
30
- self.randomize = randomize
31
-
32
- def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
33
-
34
- n_channels, n_samples = source.shape
35
-
36
- # If the audio is shorter than the desired length, pad it
37
- upper_bound = max(0, n_samples - self.n_samples)
38
-
39
- # If randomize is False, always start at the beginning of the audio
40
- offset = 0
41
- if(self.randomize and n_samples > self.n_samples):
42
- offset = random.randint(0, upper_bound)
43
-
44
- # Calculate the start and end times of the chunk
45
- t_start = offset / (upper_bound + self.n_samples)
46
- t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
47
-
48
- # Create the chunk
49
- chunk = source.new_zeros([n_channels, self.n_samples])
50
-
51
- # Copy the audio into the chunk
52
- chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
53
-
54
- # Calculate the start and end times of the chunk in seconds
55
- seconds_start = math.floor(offset / self.sample_rate)
56
- seconds_total = math.ceil(n_samples / self.sample_rate)
57
-
58
- # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
59
- padding_mask = torch.zeros([self.n_samples])
60
- padding_mask[:min(n_samples, self.n_samples)] = 1
61
-
62
-
63
- return (
64
- chunk,
65
- t_start,
66
- t_end,
67
- seconds_start,
68
- seconds_total,
69
- padding_mask
70
- )
71
-
72
- class PhaseFlipper(nn.Module):
73
- "Randomly invert the phase of a signal"
74
- def __init__(self, p=0.5):
75
- super().__init__()
76
- self.p = p
77
- def __call__(self, signal):
78
- return -signal if (random.random() < self.p) else signal
79
-
80
- class Mono(nn.Module):
81
- def __call__(self, signal):
82
- return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
83
-
84
- class Stereo(nn.Module):
85
- def __call__(self, signal):
86
- signal_shape = signal.shape
87
- # Check if it's mono
88
- if len(signal_shape) == 1: # s -> 2, s
89
- signal = signal.unsqueeze(0).repeat(2, 1)
90
- elif len(signal_shape) == 2:
91
- if signal_shape[0] == 1: #1, s -> 2, s
92
- signal = signal.repeat(2, 1)
93
- elif signal_shape[0] > 2: #?, s -> 2,s
94
- signal = signal[:2, :]
95
-
96
- return signal
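A short usage sketch of the transforms above (shapes and sample rate are illustrative only):

import torch

signal = torch.randn(1, 3 * 44100)                      # 3 s of mono audio at 44.1 kHz
signal = Stereo()(signal)                                # [1, n] -> [2, n]
signal = PhaseFlipper(p=0.5)(signal)                     # randomly inverts polarity

crop = PadCrop_Normalized_T(n_samples=2 * 44100, sample_rate=44100, randomize=True)
chunk, t_start, t_end, seconds_start, seconds_total, padding_mask = crop(signal)
# chunk: [2, 88200] fixed-length crop; t_start/t_end are relative positions in [0, 1];
# padding_mask marks which samples of the chunk are real audio rather than zero padding.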
stable/build/lib/stable_audio_tools/inference/__init__.py DELETED
File without changes
stable/build/lib/stable_audio_tools/inference/generation.py DELETED
@@ -1,274 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import typing as tp
4
- import math
5
- from torchaudio import transforms as T
6
-
7
- from .utils import prepare_audio
8
- from .sampling import sample, sample_k, sample_rf
9
- from ..data.utils import PadCrop
10
-
11
- def generate_diffusion_uncond(
12
- model,
13
- steps: int = 250,
14
- batch_size: int = 1,
15
- sample_size: int = 2097152,
16
- seed: int = -1,
17
- device: str = "cuda",
18
- init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
19
- init_noise_level: float = 1.0,
20
- return_latents = False,
21
- **sampler_kwargs
22
- ) -> torch.Tensor:
23
-
24
- # The length of the output in audio samples
25
- audio_sample_size = sample_size
26
-
27
- # If this is latent diffusion, change sample_size instead to the downsampled latent size
28
- if model.pretransform is not None:
29
- sample_size = sample_size // model.pretransform.downsampling_ratio
30
-
31
- # Seed
32
- # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
33
- seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
34
- print(seed)
35
- torch.manual_seed(seed)
36
- # Define the initial noise immediately after setting the seed
37
- noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
38
-
39
- if init_audio is not None:
40
- # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
41
- in_sr, init_audio = init_audio
42
-
43
- io_channels = model.io_channels
44
-
45
- # For latent models, set the io_channels to the autoencoder's io_channels
46
- if model.pretransform is not None:
47
- io_channels = model.pretransform.io_channels
48
-
49
- # Prepare the initial audio for use by the model
50
- init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
51
-
52
- # For latent models, encode the initial audio into latents
53
- if model.pretransform is not None:
54
- init_audio = model.pretransform.encode(init_audio)
55
-
56
- init_audio = init_audio.repeat(batch_size, 1, 1)
57
- else:
58
- # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
59
- init_audio = None
60
- init_noise_level = None
61
-
62
- # Inpainting mask
63
-
64
- if init_audio is not None:
65
- # variations
66
- sampler_kwargs["sigma_max"] = init_noise_level
67
- mask = None
68
- else:
69
- mask = None
70
-
71
- # Now the generative AI part:
72
-
73
- diff_objective = model.diffusion_objective
74
-
75
- if diff_objective == "v":
76
- # k-diffusion denoising process go!
77
- sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
78
- elif diff_objective == "rectified_flow":
79
- sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
80
-
81
- # Denoising process done.
82
- # If this is latent diffusion, decode latents back into audio
83
- if model.pretransform is not None and not return_latents:
84
- sampled = model.pretransform.decode(sampled)
85
-
86
- # Return audio
87
- return sampled
88
-
89
-
90
- def generate_diffusion_cond(
91
- model,
92
- steps: int = 250,
93
- cfg_scale=6,
94
- conditioning: dict = None,
95
- conditioning_tensors: tp.Optional[dict] = None,
96
- negative_conditioning: dict = None,
97
- negative_conditioning_tensors: tp.Optional[dict] = None,
98
- batch_size: int = 1,
99
- sample_size: int = 2097152,
100
- sample_rate: int = 48000,
101
- seed: int = -1,
102
- device: str = "cuda",
103
- init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
104
- init_noise_level: float = 1.0,
105
- mask_args: dict = None,
106
- return_latents = False,
107
- **sampler_kwargs
108
- ) -> torch.Tensor:
109
- """
110
- Generate audio from a prompt using a diffusion model.
111
-
112
- Args:
113
- model: The diffusion model to use for generation.
114
- steps: The number of diffusion steps to use.
115
- cfg_scale: Classifier-free guidance scale
116
- conditioning: A dictionary of conditioning parameters to use for generation.
117
- conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
118
- batch_size: The batch size to use for generation.
119
- sample_size: The length of the audio to generate, in samples.
120
- sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
121
- seed: The random seed to use for generation, or -1 to use a random seed.
122
- device: The device to use for generation.
123
- init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
124
- init_noise_level: The noise level to use when generating from an initial audio sample.
125
- return_latents: Whether to return the latents used for generation instead of the decoded audio.
126
- **sampler_kwargs: Additional keyword arguments to pass to the sampler.
127
- """
128
-
129
- # The length of the output in audio samples
130
- audio_sample_size = sample_size
131
-
132
- # If this is latent diffusion, change sample_size instead to the downsampled latent size
133
- if model.pretransform is not None:
134
- sample_size = sample_size // model.pretransform.downsampling_ratio
135
-
136
- # Seed
137
- # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
138
- seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
139
- print(seed)
140
- torch.manual_seed(seed)
141
- # Define the initial noise immediately after setting the seed
142
- noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
143
-
144
- torch.backends.cuda.matmul.allow_tf32 = False
145
- torch.backends.cudnn.allow_tf32 = False
146
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
147
- torch.backends.cudnn.benchmark = False
148
-
149
- # Conditioning
150
- assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
151
- if conditioning_tensors is None:
152
- conditioning_tensors = model.conditioner(conditioning, device)
153
- conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
154
-
155
- if negative_conditioning is not None or negative_conditioning_tensors is not None:
156
-
157
- if negative_conditioning_tensors is None:
158
- negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
159
-
160
- negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
161
- else:
162
- negative_conditioning_tensors = {}
163
-
164
- if init_audio is not None:
165
- # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
166
- in_sr, init_audio = init_audio
167
-
168
- io_channels = model.io_channels
169
-
170
- # For latent models, set the io_channels to the autoencoder's io_channels
171
- if model.pretransform is not None:
172
- io_channels = model.pretransform.io_channels
173
-
174
- # Prepare the initial audio for use by the model
175
- init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
176
-
177
- # For latent models, encode the initial audio into latents
178
- if model.pretransform is not None:
179
- init_audio = model.pretransform.encode(init_audio)
180
-
181
- init_audio = init_audio.repeat(batch_size, 1, 1)
182
- else:
183
- # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
184
- init_audio = None
185
- init_noise_level = None
186
- mask_args = None
187
-
188
- # Inpainting mask
189
- if init_audio is not None and mask_args is not None:
190
- # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
191
- # This is helpful for forward and reverse outpainting
192
- cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
193
- pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
194
- pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
195
- assert pastefrom < pasteto, "Paste From should be less than Paste To"
196
- croplen = pasteto - pastefrom
197
- if cropfrom + croplen > sample_size:
198
- croplen = sample_size - cropfrom
199
- cropto = cropfrom + croplen
200
- pasteto = pastefrom + croplen
201
- cutpaste = init_audio.new_zeros(init_audio.shape)
202
- cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
203
- #print(cropfrom, cropto, pastefrom, pasteto)
204
- init_audio = cutpaste
205
- # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
206
- mask = build_mask(sample_size, mask_args)
207
- mask = mask.to(device)
208
- elif init_audio is not None and mask_args is None:
209
- # variations
210
- sampler_kwargs["sigma_max"] = init_noise_level
211
- mask = None
212
- else:
213
- mask = None
214
-
215
- model_dtype = next(model.model.parameters()).dtype
216
- noise = noise.type(model_dtype)
217
- conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
218
- # Now the generative AI part:
219
- # k-diffusion denoising process go!
220
-
221
- diff_objective = model.diffusion_objective
222
-
223
- if diff_objective == "v":
224
- # k-diffusion denoising process go!
225
- sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
226
- elif diff_objective == "rectified_flow":
227
-
228
- if "sigma_min" in sampler_kwargs:
229
- del sampler_kwargs["sigma_min"]
230
-
231
- if "sampler_type" in sampler_kwargs:
232
- del sampler_kwargs["sampler_type"]
233
-
234
- sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
235
-
236
- # v-diffusion:
237
- #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale)
238
- del noise
239
- del conditioning_tensors
240
- del conditioning_inputs
241
- torch.cuda.empty_cache()
242
- # Denoising process done.
243
- # If this is latent diffusion, decode latents back into audio
244
- if model.pretransform is not None and not return_latents:
245
- #cast sampled latents to pretransform dtype
246
- sampled = sampled.to(next(model.pretransform.parameters()).dtype)
247
- sampled = model.pretransform.decode(sampled)
248
-
249
- # Return audio
250
- return sampled
251
-
252
- # builds a softmask given the parameters
253
- # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
254
- # and anything between is a mixture of old/new
255
- # ideally 0.5 is half/half mixture but i haven't figured this out yet
256
- def build_mask(sample_size, mask_args):
257
- maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
258
- maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
259
- softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
260
- softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
261
- marination = mask_args["marination"]
262
- # use hann windows for softening the transition (i don't know if this is correct)
263
- hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
264
- hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
265
- # build the mask.
266
- mask = torch.zeros((sample_size))
267
- mask[maskstart:maskend] = 1
268
- mask[maskstart:maskstart+softnessL] = hannL
269
- mask[maskend-softnessR:maskend] = hannR
270
- # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
271
- if marination > 0:
272
- mask = mask * (1-marination)
273
- #print(mask)
274
- return mask
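For reference, a hedged sketch of calling generate_diffusion_cond for text-conditioned generation; model is assumed to be an already-loaded diffusion_cond model, and the conditioning keys mirror the ones used by the Gradio interface later in this diff.

audio = generate_diffusion_cond(
    model,                          # assumed: a loaded conditional diffusion model
    steps=100,
    cfg_scale=7.0,
    conditioning=[{"prompt": "warm analog synth arpeggio", "seconds_start": 0, "seconds_total": 30}],
    batch_size=1,
    sample_size=2097152,            # length in audio samples; divided by the pretransform ratio internally
    seed=-1,                        # -1 draws a random seed
    device="cuda",
    sampler_type="dpmpp-3m-sde",    # forwarded to the sampler via **sampler_kwargs
    sigma_min=0.03,
    sigma_max=500,
)
# audio: [batch_size, channels, samples], decoded from latents when a pretransform is present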
stable/build/lib/stable_audio_tools/inference/sampling.py DELETED
@@ -1,232 +0,0 @@
1
- import torch
2
- import math
3
- from tqdm import trange, tqdm
4
-
5
- import k_diffusion as K
6
-
7
- # Define the noise schedule and sampling loop
8
- def get_alphas_sigmas(t):
9
- """Returns the scaling factors for the clean image (alpha) and for the
10
- noise (sigma), given a timestep."""
11
- return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
-
13
- def alpha_sigma_to_t(alpha, sigma):
14
- """Returns a timestep, given the scaling factors for the clean image and for
15
- the noise."""
16
- return torch.atan2(sigma, alpha) / math.pi * 2
17
-
18
- def t_to_alpha_sigma(t):
19
- """Returns the scaling factors for the clean image and for the noise, given
20
- a timestep."""
21
- return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
-
23
-
24
- @torch.no_grad()
25
- def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
26
- """Draws samples from a model given starting noise. Euler method"""
27
-
28
- # Make tensor of ones to broadcast the single t values
29
- ts = x.new_ones([x.shape[0]])
30
-
31
- # Create the noise schedule
32
- t = torch.linspace(sigma_max, 0, steps + 1)
33
-
34
- #alphas, sigmas = 1-t, t
35
-
36
- for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
37
- # Broadcast the current timestep to the correct shape
38
- t_curr_tensor = t_curr * torch.ones(
39
- (x.shape[0],), dtype=x.dtype, device=x.device
40
- )
41
- dt = t_prev - t_curr # we solve backwards in our formulation
42
- x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
43
-
44
- # If we are on the last timestep, output the denoised image
45
- return x
46
-
47
- @torch.no_grad()
48
- def sample(model, x, steps, eta, **extra_args):
49
- """Draws samples from a model given starting noise. v-diffusion"""
50
- ts = x.new_ones([x.shape[0]])
51
-
52
- # Create the noise schedule
53
- t = torch.linspace(1, 0, steps + 1)[:-1]
54
-
55
- alphas, sigmas = get_alphas_sigmas(t)
56
-
57
- # The sampling loop
58
- for i in trange(steps):
59
-
60
- # Get the model output (v, the predicted velocity)
61
- with torch.cuda.amp.autocast():
62
- v = model(x, ts * t[i], **extra_args).float()
63
-
64
- # Predict the noise and the denoised image
65
- pred = x * alphas[i] - v * sigmas[i]
66
- eps = x * sigmas[i] + v * alphas[i]
67
-
68
- # If we are not on the last timestep, compute the noisy image for the
69
- # next timestep.
70
- if i < steps - 1:
71
- # If eta > 0, adjust the scaling factor for the predicted noise
72
- # downward according to the amount of additional noise to add
73
- ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
74
- (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
75
- adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
76
-
77
- # Recombine the predicted noise and predicted denoised image in the
78
- # correct proportions for the next step
79
- x = pred * alphas[i + 1] + eps * adjusted_sigma
80
-
81
- # Add the correct amount of fresh noise
82
- if eta:
83
- x += torch.randn_like(x) * ddim_sigma
84
-
85
- # If we are on the last timestep, output the denoised image
86
- return pred
87
-
88
- # Soft mask inpainting is just shrinking hard (binary) mask inpainting
89
- # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
90
- def get_bmask(i, steps, mask):
91
- strength = (i+1)/(steps)
92
- # convert to binary mask
93
- bmask = torch.where(mask<=strength,1,0)
94
- return bmask
95
-
96
- def make_cond_model_fn(model, cond_fn):
97
- def cond_model_fn(x, sigma, **kwargs):
98
- with torch.enable_grad():
99
- x = x.detach().requires_grad_()
100
- denoised = model(x, sigma, **kwargs)
101
- cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
102
- cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
103
- return cond_denoised
104
- return cond_model_fn
105
-
106
- # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
107
- # init_data is init_audio as latents (if this is latent diffusion)
108
- # For sampling, set both init_data and mask to None
109
- # For variations, set init_data
110
- # For inpainting, set both init_data & mask
111
- def sample_k(
112
- model_fn,
113
- noise,
114
- init_data=None,
115
- mask=None,
116
- steps=100,
117
- sampler_type="dpmpp-2m-sde",
118
- sigma_min=0.5,
119
- sigma_max=50,
120
- rho=1.0, device="cuda",
121
- callback=None,
122
- cond_fn=None,
123
- **extra_args
124
- ):
125
-
126
- denoiser = K.external.VDenoiser(model_fn)
127
-
128
- if cond_fn is not None:
129
- denoiser = make_cond_model_fn(denoiser, cond_fn)
130
-
131
- # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
132
- sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
133
- # Scale the initial noise by sigma
134
- noise = noise * sigmas[0]
135
-
136
- wrapped_callback = callback
137
-
138
- if mask is None and init_data is not None:
139
- # VARIATION (no inpainting)
140
- # set the initial latent to the init_data, and noise it with initial sigma
141
- x = init_data + noise
142
- elif mask is not None and init_data is not None:
143
- # INPAINTING
144
- bmask = get_bmask(0, steps, mask)
145
- # initial noising
146
- input_noised = init_data + noise
147
- # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
148
- x = input_noised * bmask + noise * (1-bmask)
149
- # define the inpainting callback function (Note: side effects, it mutates x)
150
- # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
151
- # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
152
- # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
153
- def inpainting_callback(args):
154
- i = args["i"]
155
- x = args["x"]
156
- sigma = args["sigma"]
157
- #denoised = args["denoised"]
158
- # noise the init_data input with this step's appropriate amount of noise
159
- input_noised = init_data + torch.randn_like(init_data) * sigma
160
- # shrinking hard mask
161
- bmask = get_bmask(i, steps, mask)
162
- # mix input_noise with x, using binary mask
163
- new_x = input_noised * bmask + x * (1-bmask)
164
- # mutate x
165
- x[:,:,:] = new_x[:,:,:]
166
- # wrap together the inpainting callback and the user-submitted callback.
167
- if callback is None:
168
- wrapped_callback = inpainting_callback
169
- else:
170
- wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
171
- else:
172
- # SAMPLING
173
- # set the initial latent to noise
174
- x = noise
175
-
176
-
177
- with torch.cuda.amp.autocast():
178
- if sampler_type == "k-heun":
179
- return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
180
- elif sampler_type == "k-lms":
181
- return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
182
- elif sampler_type == "k-dpmpp-2s-ancestral":
183
- return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
184
- elif sampler_type == "k-dpm-2":
185
- return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
186
- elif sampler_type == "k-dpm-fast":
187
- return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
188
- elif sampler_type == "k-dpm-adaptive":
189
- return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
190
- elif sampler_type == "dpmpp-2m-sde":
191
- return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
192
- elif sampler_type == "dpmpp-3m-sde":
193
- return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
194
-
195
- # Uses discrete Euler sampling for rectified flow models
196
- # init_data is init_audio as latents (if this is latent diffusion)
197
- # For sampling, set both init_data and mask to None
198
- # For variations, set init_data
199
- # For inpainting, set both init_data & mask
200
- def sample_rf(
201
- model_fn,
202
- noise,
203
- init_data=None,
204
- steps=100,
205
- sigma_max=1,
206
- device="cuda",
207
- callback=None,
208
- cond_fn=None,
209
- **extra_args
210
- ):
211
-
212
- if sigma_max > 1:
213
- sigma_max = 1
214
-
215
- if cond_fn is not None:
216
- model_fn = make_cond_model_fn(model_fn, cond_fn)
217
-
218
- wrapped_callback = callback
219
-
220
- if init_data is not None:
221
- # VARIATION (no inpainting)
222
- # Interpolate the init data and the noise for init audio
223
- x = init_data * (1 - sigma_max) + noise * sigma_max
224
- else:
225
- # SAMPLING
226
- # set the initial latent to noise
227
- x = noise
228
-
229
- with torch.cuda.amp.autocast():
230
- # TODO: Add callback support
231
- #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
232
- return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
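A minimal sketch of driving sample_k directly (plain sampling, no variation or inpainting); the latent shape and the callable model_fn are assumptions, with model_fn(x, sigma, **extra_args) expected to return a v-prediction as required by K.external.VDenoiser.

import torch

noise = torch.randn(1, 64, 1024, device="cuda")   # illustrative [batch, channels, length] latent
latents = sample_k(
    model_fn,                 # assumed: the diffusion network supplied by the caller
    noise,
    init_data=None,           # set init_data for variations, init_data + mask for inpainting
    mask=None,
    steps=100,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    device="cuda",
)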
stable/build/lib/stable_audio_tools/inference/utils.py DELETED
@@ -1,35 +0,0 @@
- from ..data.utils import PadCrop
-
- from torchaudio import transforms as T
-
- def set_audio_channels(audio, target_channels):
-     if target_channels == 1:
-         # Convert to mono
-         audio = audio.mean(1, keepdim=True)
-     elif target_channels == 2:
-         # Convert to stereo
-         if audio.shape[1] == 1:
-             audio = audio.repeat(1, 2, 1)
-         elif audio.shape[1] > 2:
-             audio = audio[:, :2, :]
-     return audio
-
- def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
-
-     audio = audio.to(device)
-
-     if in_sr != target_sr:
-         resample_tf = T.Resample(in_sr, target_sr).to(device)
-         audio = resample_tf(audio)
-
-     audio = PadCrop(target_length, randomize=False)(audio)
-
-     # Add batch dimension
-     if audio.dim() == 1:
-         audio = audio.unsqueeze(0).unsqueeze(0)
-     elif audio.dim() == 2:
-         audio = audio.unsqueeze(0)
-
-     audio = set_audio_channels(audio, target_channels)
-
-     return audio
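A brief usage sketch of prepare_audio; the file path, target sample rate, and length are placeholders.

import torchaudio

waveform, in_sr = torchaudio.load("input.wav")      # [channels, n]
audio = prepare_audio(
    waveform,
    in_sr=in_sr,
    target_sr=44100,
    target_length=10 * 44100,
    target_channels=2,
    device="cuda",
)
# audio: [1, 2, 441000] -- resampled, padded/cropped, batched, and channel-matched on the target device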
stable/build/lib/stable_audio_tools/interface/__init__.py DELETED
File without changes
stable/build/lib/stable_audio_tools/interface/gradio.py DELETED
@@ -1,700 +0,0 @@
1
- import gc
2
- import platform
3
-
4
- import numpy as np
5
- import gradio as gr
6
- import json
7
- import torch
8
- import torchaudio
9
-
10
- from aeiou.viz import audio_spectrogram_image
11
- from einops import rearrange
12
- from safetensors.torch import load_file
13
- from torch.nn import functional as F
14
- from torchaudio import transforms as T
15
-
16
- from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
17
- from ..models.factory import create_model_from_config
18
- from ..models.pretrained import get_pretrained_model
19
- from ..models.utils import load_ckpt_state_dict
20
- from ..inference.utils import prepare_audio
21
- from ..training.utils import copy_state_dict
22
-
23
- model = None
24
- sample_rate = 32000
25
- sample_size = 1920000
26
-
27
- def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
28
- global model, sample_rate, sample_size
29
-
30
- if pretrained_name is not None:
31
- print(f"Loading pretrained model {pretrained_name}")
32
- model, model_config = get_pretrained_model(pretrained_name)
33
-
34
- elif model_config is not None and model_ckpt_path is not None:
35
- print(f"Creating model from config")
36
- model = create_model_from_config(model_config)
37
-
38
- print(f"Loading model checkpoint from {model_ckpt_path}")
39
- # Load checkpoint
40
- copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
41
- #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
42
-
43
- sample_rate = model_config["sample_rate"]
44
- sample_size = model_config["sample_size"]
45
-
46
- if pretransform_ckpt_path is not None:
47
- print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
48
- model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
49
- print(f"Done loading pretransform")
50
-
51
- model.to(device).eval().requires_grad_(False)
52
-
53
- if model_half:
54
- model.to(torch.float16)
55
-
56
- print(f"Done loading model")
57
-
58
- return model, model_config
59
-
60
- def generate_cond(
61
- prompt,
62
- negative_prompt=None,
63
- seconds_start=0,
64
- seconds_total=30,
65
- cfg_scale=6.0,
66
- steps=250,
67
- preview_every=None,
68
- seed=-1,
69
- sampler_type="dpmpp-3m-sde",
70
- sigma_min=0.03,
71
- sigma_max=1000,
72
- cfg_rescale=0.0,
73
- use_init=False,
74
- init_audio=None,
75
- init_noise_level=1.0,
76
- mask_cropfrom=None,
77
- mask_pastefrom=None,
78
- mask_pasteto=None,
79
- mask_maskstart=None,
80
- mask_maskend=None,
81
- mask_softnessL=None,
82
- mask_softnessR=None,
83
- mask_marination=None,
84
- batch_size=1
85
- ):
86
-
87
- if torch.cuda.is_available():
88
- torch.cuda.empty_cache()
89
- gc.collect()
90
-
91
- print(f"Prompt: {prompt}")
92
-
93
- global preview_images
94
- preview_images = []
95
- if preview_every == 0:
96
- preview_every = None
97
-
98
- # Return fake stereo audio
99
- conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size
100
-
101
- if negative_prompt:
102
- negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size
103
- else:
104
- negative_conditioning = None
105
-
106
- #Get the device from the model
107
- device = next(model.parameters()).device
108
-
109
- seed = int(seed)
110
-
111
- if not use_init:
112
- init_audio = None
113
-
114
- input_sample_size = sample_size
115
-
116
- if init_audio is not None:
117
- in_sr, init_audio = init_audio
118
- # Turn into torch tensor, converting from int16 to float32
119
- init_audio = torch.from_numpy(init_audio).float().div(32767)
120
-
121
- if init_audio.dim() == 1:
122
- init_audio = init_audio.unsqueeze(0) # [1, n]
123
- elif init_audio.dim() == 2:
124
- init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
125
-
126
- if in_sr != sample_rate:
127
- resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
128
- init_audio = resample_tf(init_audio)
129
-
130
- audio_length = init_audio.shape[-1]
131
-
132
- if audio_length > sample_size:
133
-
134
- input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
135
-
136
- init_audio = (sample_rate, init_audio)
137
-
138
- def progress_callback(callback_info):
139
- global preview_images
140
- denoised = callback_info["denoised"]
141
- current_step = callback_info["i"]
142
- sigma = callback_info["sigma"]
143
-
144
- if (current_step - 1) % preview_every == 0:
145
- if model.pretransform is not None:
146
- denoised = model.pretransform.decode(denoised)
147
- denoised = rearrange(denoised, "b d n -> d (b n)")
148
- denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
149
- audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
150
- preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
151
-
152
- # If inpainting, send mask args
153
- # This will definitely change in the future
154
- if mask_cropfrom is not None:
155
- mask_args = {
156
- "cropfrom": mask_cropfrom,
157
- "pastefrom": mask_pastefrom,
158
- "pasteto": mask_pasteto,
159
- "maskstart": mask_maskstart,
160
- "maskend": mask_maskend,
161
- "softnessL": mask_softnessL,
162
- "softnessR": mask_softnessR,
163
- "marination": mask_marination,
164
- }
165
- else:
166
- mask_args = None
167
-
168
- # Do the audio generation
169
- audio = generate_diffusion_cond(
170
- model,
171
- conditioning=conditioning,
172
- negative_conditioning=negative_conditioning,
173
- steps=steps,
174
- cfg_scale=cfg_scale,
175
- batch_size=batch_size,
176
- sample_size=input_sample_size,
177
- sample_rate=sample_rate,
178
- seed=seed,
179
- device=device,
180
- sampler_type=sampler_type,
181
- sigma_min=sigma_min,
182
- sigma_max=sigma_max,
183
- init_audio=init_audio,
184
- init_noise_level=init_noise_level,
185
- mask_args = mask_args,
186
- callback = progress_callback if preview_every is not None else None,
187
- scale_phi = cfg_rescale
188
- )
189
-
190
- # Convert to WAV file
191
- audio = rearrange(audio, "b d n -> d (b n)")
192
- audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
193
- torchaudio.save("output.wav", audio, sample_rate)
194
-
195
- # Let's look at a nice spectrogram too
196
- audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
197
-
198
- return ("output.wav", [audio_spectrogram, *preview_images])
199
-
200
- def generate_uncond(
201
- steps=250,
202
- seed=-1,
203
- sampler_type="dpmpp-3m-sde",
204
- sigma_min=0.03,
205
- sigma_max=1000,
206
- use_init=False,
207
- init_audio=None,
208
- init_noise_level=1.0,
209
- batch_size=1,
210
- preview_every=None
211
- ):
212
-
213
- global preview_images
214
-
215
- preview_images = []
216
-
217
- if torch.cuda.is_available():
218
- torch.cuda.empty_cache()
219
- gc.collect()
220
-
221
- #Get the device from the model
222
- device = next(model.parameters()).device
223
-
224
- seed = int(seed)
225
-
226
- if not use_init:
227
- init_audio = None
228
-
229
- input_sample_size = sample_size
230
-
231
- if init_audio is not None:
232
- in_sr, init_audio = init_audio
233
- # Turn into torch tensor, converting from int16 to float32
234
- init_audio = torch.from_numpy(init_audio).float().div(32767)
235
-
236
- if init_audio.dim() == 1:
237
- init_audio = init_audio.unsqueeze(0) # [1, n]
238
- elif init_audio.dim() == 2:
239
- init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
240
-
241
- if in_sr != sample_rate:
242
- resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
243
- init_audio = resample_tf(init_audio)
244
-
245
- audio_length = init_audio.shape[-1]
246
-
247
- if audio_length > sample_size:
248
-
249
- input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
250
-
251
- init_audio = (sample_rate, init_audio)
252
-
253
- def progress_callback(callback_info):
254
- global preview_images
255
- denoised = callback_info["denoised"]
256
- current_step = callback_info["i"]
257
- sigma = callback_info["sigma"]
258
-
259
- if (current_step - 1) % preview_every == 0:
260
-
261
- if model.pretransform is not None:
262
- denoised = model.pretransform.decode(denoised)
263
-
264
- denoised = rearrange(denoised, "b d n -> d (b n)")
265
-
266
- denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
267
-
268
- audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
269
-
270
- preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
271
-
272
- audio = generate_diffusion_uncond(
273
- model,
274
- steps=steps,
275
- batch_size=batch_size,
276
- sample_size=input_sample_size,
277
- seed=seed,
278
- device=device,
279
- sampler_type=sampler_type,
280
- sigma_min=sigma_min,
281
- sigma_max=sigma_max,
282
- init_audio=init_audio,
283
- init_noise_level=init_noise_level,
284
- callback = progress_callback if preview_every is not None else None
285
- )
286
-
287
- audio = rearrange(audio, "b d n -> d (b n)")
288
-
289
- audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
290
-
291
- torchaudio.save("output.wav", audio, sample_rate)
292
-
293
- audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
294
-
295
- return ("output.wav", [audio_spectrogram, *preview_images])
296
-
297
- def generate_lm(
298
- temperature=1.0,
299
- top_p=0.95,
300
- top_k=0,
301
- batch_size=1,
302
- ):
303
-
304
- if torch.cuda.is_available():
305
- torch.cuda.empty_cache()
306
- gc.collect()
307
-
308
- #Get the device from the model
309
- device = next(model.parameters()).device
310
-
311
- audio = model.generate_audio(
312
- batch_size=batch_size,
313
- max_gen_len = sample_size//model.pretransform.downsampling_ratio,
314
- conditioning=None,
315
- temp=temperature,
316
- top_p=top_p,
317
- top_k=top_k,
318
- use_cache=True
319
- )
320
-
321
- audio = rearrange(audio, "b d n -> d (b n)")
322
-
323
- audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
324
-
325
- torchaudio.save("output.wav", audio, sample_rate)
326
-
327
- audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
328
-
329
- return ("output.wav", [audio_spectrogram])
330
-
331
-
332
- def create_uncond_sampling_ui(model_config):
333
- generate_button = gr.Button("Generate", variant='primary', scale=1)
334
-
335
- with gr.Row(equal_height=False):
336
- with gr.Column():
337
- with gr.Row():
338
- # Steps slider
339
- steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
340
-
341
- with gr.Accordion("Sampler params", open=False):
342
-
343
- # Seed
344
- seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
345
-
346
- # Sampler params
347
- with gr.Row():
348
- sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
349
- sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
350
- sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
351
-
352
- with gr.Accordion("Init audio", open=False):
353
- init_audio_checkbox = gr.Checkbox(label="Use init audio")
354
- init_audio_input = gr.Audio(label="Init audio")
355
- init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
356
-
357
- with gr.Column():
358
- audio_output = gr.Audio(label="Output audio", interactive=False)
359
- audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
360
- send_to_init_button = gr.Button("Send to init audio", scale=1)
361
- send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
362
-
363
- generate_button.click(fn=generate_uncond,
364
- inputs=[
365
- steps_slider,
366
- seed_textbox,
367
- sampler_type_dropdown,
368
- sigma_min_slider,
369
- sigma_max_slider,
370
- init_audio_checkbox,
371
- init_audio_input,
372
- init_noise_level_slider,
373
- ],
374
- outputs=[
375
- audio_output,
376
- audio_spectrogram_output
377
- ],
378
- api_name="generate")
379
-
380
- def create_sampling_ui(model_config, inpainting=False):
381
- with gr.Row():
382
- with gr.Column(scale=6):
383
- prompt = gr.Textbox(show_label=False, placeholder="Prompt")
384
- negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
385
- generate_button = gr.Button("Generate", variant='primary', scale=1)
386
-
387
- model_conditioning_config = model_config["model"].get("conditioning", None)
388
-
389
- has_seconds_start = False
390
- has_seconds_total = False
391
-
392
- if model_conditioning_config is not None:
393
- for conditioning_config in model_conditioning_config["configs"]:
394
- if conditioning_config["id"] == "seconds_start":
395
- has_seconds_start = True
396
- if conditioning_config["id"] == "seconds_total":
397
- has_seconds_total = True
398
-
399
- with gr.Row(equal_height=False):
400
- with gr.Column():
401
- with gr.Row(visible = has_seconds_start or has_seconds_total):
402
- # Timing controls
403
- seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start)
404
- seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
405
-
406
- with gr.Row():
407
- # Steps slider
408
- steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
409
-
410
- # Preview Every slider
411
- preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
412
-
413
- # CFG scale
414
- cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale")
415
-
416
- with gr.Accordion("Sampler params", open=False):
417
-
418
- # Seed
419
- seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
420
-
421
- # Sampler params
422
- with gr.Row():
423
- sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
424
- sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
425
- sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
426
- cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount")
427
-
428
- if inpainting:
429
- # Inpainting Tab
430
- with gr.Accordion("Inpainting", open=False):
431
- sigma_max_slider.maximum=1000
432
-
433
- init_audio_checkbox = gr.Checkbox(label="Do inpainting")
434
- init_audio_input = gr.Audio(label="Init audio")
435
- init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this
436
-
437
- mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %")
438
- mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %")
439
- mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %")
440
-
441
- mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %")
442
- mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %")
443
- mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %")
444
- mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %")
445
- mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this
446
-
447
- inputs = [prompt,
448
- negative_prompt,
449
- seconds_start_slider,
450
- seconds_total_slider,
451
- cfg_scale_slider,
452
- steps_slider,
453
- preview_every_slider,
454
- seed_textbox,
455
- sampler_type_dropdown,
456
- sigma_min_slider,
457
- sigma_max_slider,
458
- cfg_rescale_slider,
459
- init_audio_checkbox,
460
- init_audio_input,
461
- init_noise_level_slider,
462
- mask_cropfrom_slider,
463
- mask_pastefrom_slider,
464
- mask_pasteto_slider,
465
- mask_maskstart_slider,
466
- mask_maskend_slider,
467
- mask_softnessL_slider,
468
- mask_softnessR_slider,
469
- mask_marination_slider
470
- ]
471
- else:
472
- # Default generation tab
473
- with gr.Accordion("Init audio", open=False):
474
- init_audio_checkbox = gr.Checkbox(label="Use init audio")
475
- init_audio_input = gr.Audio(label="Init audio")
476
- init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
477
-
478
- inputs = [prompt,
479
- negative_prompt,
480
- seconds_start_slider,
481
- seconds_total_slider,
482
- cfg_scale_slider,
483
- steps_slider,
484
- preview_every_slider,
485
- seed_textbox,
486
- sampler_type_dropdown,
487
- sigma_min_slider,
488
- sigma_max_slider,
489
- cfg_rescale_slider,
490
- init_audio_checkbox,
491
- init_audio_input,
492
- init_noise_level_slider
493
- ]
494
-
495
- with gr.Column():
496
- audio_output = gr.Audio(label="Output audio", interactive=False)
497
- audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
498
- send_to_init_button = gr.Button("Send to init audio", scale=1)
499
- send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
500
-
501
- generate_button.click(fn=generate_cond,
502
- inputs=inputs,
503
- outputs=[
504
- audio_output,
505
- audio_spectrogram_output
506
- ],
507
- api_name="generate")
508
-
509
-
510
- def create_txt2audio_ui(model_config):
511
- with gr.Blocks() as ui:
512
- with gr.Tab("Generation"):
513
- create_sampling_ui(model_config)
514
- with gr.Tab("Inpainting"):
515
- create_sampling_ui(model_config, inpainting=True)
516
- return ui
517
-
518
- def create_diffusion_uncond_ui(model_config):
519
- with gr.Blocks() as ui:
520
- create_uncond_sampling_ui(model_config)
521
-
522
- return ui
523
-
524
- def autoencoder_process(audio, latent_noise, n_quantizers):
525
- if torch.cuda.is_available():
526
- torch.cuda.empty_cache()
527
- gc.collect()
528
-
529
- #Get the device from the model
530
- device = next(model.parameters()).device
531
-
532
- in_sr, audio = audio
533
-
534
- audio = torch.from_numpy(audio).float().div(32767).to(device)
535
-
536
- if audio.dim() == 1:
537
- audio = audio.unsqueeze(0)
538
- else:
539
- audio = audio.transpose(0, 1)
540
-
541
- audio = model.preprocess_audio_for_encoder(audio, in_sr)
542
- # Note: If you need to do chunked encoding, to reduce VRAM,
543
- # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
544
- # To turn it off, do chunked=False
545
- # Optimal overlap and chunk_size values will depend on the model.
546
- # See encode_audio & decode_audio in autoencoders.py for more info
547
- # Get dtype of model
548
- dtype = next(model.parameters()).dtype
549
-
550
- audio = audio.to(dtype)
551
-
552
- if n_quantizers > 0:
553
- latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
554
- else:
555
- latents = model.encode_audio(audio, chunked=False)
556
-
557
- if latent_noise > 0:
558
- latents = latents + torch.randn_like(latents) * latent_noise
559
-
560
- audio = model.decode_audio(latents, chunked=False)
561
-
562
- audio = rearrange(audio, "b d n -> d (b n)")
563
-
564
- audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
565
-
566
- torchaudio.save("output.wav", audio, sample_rate)
567
-
568
- return "output.wav"
569
-
570
- def create_autoencoder_ui(model_config):
571
-
572
- is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"]
573
-
574
- if is_dac_rvq:
575
- n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"]
576
- else:
577
- n_quantizers = 0
578
-
579
- with gr.Blocks() as ui:
580
- input_audio = gr.Audio(label="Input audio")
581
- output_audio = gr.Audio(label="Output audio", interactive=False)
582
- n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq)
583
- latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise")
584
- process_button = gr.Button("Process", variant='primary', scale=1)
585
- process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process")
586
-
587
- return ui
588
-
589
- def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max):
590
-
591
- if torch.cuda.is_available():
592
- torch.cuda.empty_cache()
593
- gc.collect()
594
-
595
- #Get the device from the model
596
- device = next(model.parameters()).device
597
-
598
- in_sr, audio = audio
599
-
600
- audio = torch.from_numpy(audio).float().div(32767).to(device)
601
-
602
- if audio.dim() == 1:
603
- audio = audio.unsqueeze(0) # [1, n]
604
- elif audio.dim() == 2:
605
- audio = audio.transpose(0, 1) # [n, 2] -> [2, n]
606
-
607
- audio = audio.unsqueeze(0)
608
-
609
- audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max})
610
-
611
- audio = rearrange(audio, "b d n -> d (b n)")
612
-
613
- audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
614
-
615
- torchaudio.save("output.wav", audio, sample_rate)
616
-
617
- return "output.wav"
618
-
619
- def create_diffusion_prior_ui(model_config):
620
- with gr.Blocks() as ui:
621
- input_audio = gr.Audio(label="Input audio")
622
- output_audio = gr.Audio(label="Output audio", interactive=False)
623
- # Sampler params
624
- with gr.Row():
625
- steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
626
- sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
627
- sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
628
- sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
629
- process_button = gr.Button("Process", variant='primary', scale=1)
630
- process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process")
631
-
632
- return ui
633
-
634
- def create_lm_ui(model_config):
635
- with gr.Blocks() as ui:
636
- output_audio = gr.Audio(label="Output audio", interactive=False)
637
- audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
638
-
639
- # Sampling params
640
- with gr.Row():
641
- temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature")
642
- top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p")
643
- top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k")
644
-
645
- generate_button = gr.Button("Generate", variant='primary', scale=1)
646
- generate_button.click(
647
- fn=generate_lm,
648
- inputs=[
649
- temperature_slider,
650
- top_p_slider,
651
- top_k_slider
652
- ],
653
- outputs=[output_audio, audio_spectrogram_output],
654
- api_name="generate"
655
- )
656
-
657
- return ui
658
-
659
- def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
660
-
661
- assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"
662
-
663
- if model_config_path is not None:
664
- # Load config from json file
665
- with open(model_config_path) as f:
666
- model_config = json.load(f)
667
- else:
668
- model_config = None
669
-
670
- try:
671
- has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
672
- except Exception:
673
- # In case this version of Torch doesn't even have `torch.backends.mps`...
674
- has_mps = False
675
-
676
- if has_mps:
677
- device = torch.device("mps")
678
- elif torch.cuda.is_available():
679
- device = torch.device("cuda")
680
- else:
681
- device = torch.device("cpu")
682
-
683
- print("Using device:", device)
684
-
685
- _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)
686
-
687
- model_type = model_config["model_type"]
688
-
689
- if model_type == "diffusion_cond":
690
- ui = create_txt2audio_ui(model_config)
691
- elif model_type == "diffusion_uncond":
692
- ui = create_diffusion_uncond_ui(model_config)
693
- elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
694
- ui = create_autoencoder_ui(model_config)
695
- elif model_type == "diffusion_prior":
696
- ui = create_diffusion_prior_ui(model_config)
697
- elif model_type == "lm":
698
- ui = create_lm_ui(model_config)
699
-
700
- return ui
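For completeness, a hedged sketch of launching this interface from a script; the config and checkpoint paths are placeholders, and create_ui expects either a pretrained name or a config/checkpoint pair, not both.

ui = create_ui(
    model_config_path="model_config.json",   # hypothetical path to a model config JSON
    ckpt_path="model.ckpt",                  # hypothetical path to a matching checkpoint
    model_half=False,
)
ui.launch()                                  # standard Gradio Blocks launch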
stable/build/lib/stable_audio_tools/models/__init__.py DELETED
@@ -1 +0,0 @@
- from .factory import create_model_from_config, create_model_from_config_path
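A small usage sketch of the two factory entry points re-exported here; the config path is a placeholder, and the dispatch on model_config["model_type"] is inferred from how the rest of this package uses the config.

import json

with open("model_config.json") as f:             # hypothetical config file
    model_config = json.load(f)

model = create_model_from_config(model_config)   # builds the model described by the config dict

# or, letting the factory load the JSON itself:
model = create_model_from_config_path("model_config.json")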
stable/build/lib/stable_audio_tools/models/adp.py DELETED
@@ -1,1588 +0,0 @@
1
- # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
- # License can be found in LICENSES/LICENSE_ADP.txt
3
-
4
- import math
5
- from inspect import isfunction
6
- from math import ceil, floor, log, pi, log2
7
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
- from packaging import version
9
-
10
- import torch
11
- import torch.nn as nn
12
- from einops import rearrange, reduce, repeat
13
- from einops.layers.torch import Rearrange
14
- from einops_exts import rearrange_many
15
- from torch import Tensor, einsum
16
- from torch.backends.cuda import sdp_kernel
17
- from torch.nn import functional as F
18
- from dac.nn.layers import Snake1d
19
-
20
- """
21
- Utils
22
- """
23
-
24
-
25
- class ConditionedSequential(nn.Module):
26
- def __init__(self, *modules):
27
- super().__init__()
28
- self.module_list = nn.ModuleList(*modules)
29
-
30
- def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
- for module in self.module_list:
32
- x = module(x, mapping)
33
- return x
34
-
35
- T = TypeVar("T")
36
-
37
- def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
38
- if exists(val):
39
- return val
40
- return d() if isfunction(d) else d
41
-
42
- def exists(val: Optional[T]) -> T:
43
- return val is not None
44
-
45
- def closest_power_2(x: float) -> int:
46
- exponent = log2(x)
47
- distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
- exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
- return 2 ** int(exponent_closest)
50
-
51
- def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
52
- return_dicts: Tuple[Dict, Dict] = ({}, {})
53
- for key in d.keys():
54
- no_prefix = int(not key.startswith(prefix))
55
- return_dicts[no_prefix][key] = d[key]
56
- return return_dicts
57
-
58
- def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
59
- kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
60
- if keep_prefix:
61
- return kwargs_with_prefix, kwargs
62
- kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
63
- return kwargs_no_prefix, kwargs
64
-
65
- """
66
- Convolutional Blocks
67
- """
68
- import typing as tp
69
-
70
- # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
71
- # License available in LICENSES/LICENSE_META.txt
72
-
73
- def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
74
- padding_total: int = 0) -> int:
75
- """See `pad_for_conv1d`."""
76
- length = x.shape[-1]
77
- n_frames = (length - kernel_size + padding_total) / stride + 1
78
- ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
79
- return ideal_length - length
80
-
81
-
82
- def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
83
- """Pad for a convolution to make sure that the last window is full.
84
- Extra padding is added at the end. This is required to ensure that we can rebuild
85
- an output of the same length, as otherwise, even with padding, some time steps
86
- might get removed.
87
- For instance, with total padding = 4, kernel size = 4, stride = 2:
88
- 0 0 1 2 3 4 5 0 0 # (0s are padding)
89
- 1 2 3 # (output frames of a convolution, last 0 is never used)
90
- 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
91
- 1 2 3 4 # once you removed padding, we are missing one time step !
92
- """
93
- extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
94
- return F.pad(x, (0, extra_padding))
95
-
96
-
97
- def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
98
- """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
99
- If this is the case, we insert extra 0 padding to the right before the reflection happen.
100
- """
101
- length = x.shape[-1]
102
- padding_left, padding_right = paddings
103
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
104
- if mode == 'reflect':
105
- max_pad = max(padding_left, padding_right)
106
- extra_pad = 0
107
- if length <= max_pad:
108
- extra_pad = max_pad - length + 1
109
- x = F.pad(x, (0, extra_pad))
110
- padded = F.pad(x, paddings, mode, value)
111
- end = padded.shape[-1] - extra_pad
112
- return padded[..., :end]
113
- else:
114
- return F.pad(x, paddings, mode, value)
115
-
116
-
117
- def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
118
- """Remove padding from x, handling properly zero padding. Only for 1d!"""
119
- padding_left, padding_right = paddings
120
- assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
121
- assert (padding_left + padding_right) <= x.shape[-1]
122
- end = x.shape[-1] - padding_right
123
- return x[..., padding_left: end]
124
-
125
-
126
- class Conv1d(nn.Conv1d):
127
- def __init__(self, *args, **kwargs):
128
- super().__init__(*args, **kwargs)
129
-
130
- def forward(self, x: Tensor, causal=False) -> Tensor:
131
- kernel_size = self.kernel_size[0]
132
- stride = self.stride[0]
133
- dilation = self.dilation[0]
134
- kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
135
- padding_total = kernel_size - stride
136
- extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
137
- if causal:
138
- # Left padding for causal
139
- x = pad1d(x, (padding_total, extra_padding))
140
- else:
141
- # Asymmetric padding required for odd strides
142
- padding_right = padding_total // 2
143
- padding_left = padding_total - padding_right
144
- x = pad1d(x, (padding_left, padding_right + extra_padding))
145
- return super().forward(x)
146
-
147
- class ConvTranspose1d(nn.ConvTranspose1d):
148
- def __init__(self, *args, **kwargs):
149
- super().__init__(*args, **kwargs)
150
-
151
- def forward(self, x: Tensor, causal=False) -> Tensor:
152
- kernel_size = self.kernel_size[0]
153
- stride = self.stride[0]
154
- padding_total = kernel_size - stride
155
-
156
- y = super().forward(x)
157
-
158
- # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
159
- # removed at the very end, when keeping only the right length for the output,
160
- # as removing it here would require also passing the length at the matching layer
161
- # in the encoder.
162
- if causal:
163
- padding_right = ceil(padding_total)
164
- padding_left = padding_total - padding_right
165
- y = unpad1d(y, (padding_left, padding_right))
166
- else:
167
- # Asymmetric padding required for odd strides
168
- padding_right = padding_total // 2
169
- padding_left = padding_total - padding_right
170
- y = unpad1d(y, (padding_left, padding_right))
171
- return y
172
-
173
-
174
- def Downsample1d(
175
- in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
176
- ) -> nn.Module:
177
- assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
178
-
179
- return Conv1d(
180
- in_channels=in_channels,
181
- out_channels=out_channels,
182
- kernel_size=factor * kernel_multiplier + 1,
183
- stride=factor
184
- )
185
-
186
-
187
- def Upsample1d(
188
- in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
189
- ) -> nn.Module:
190
-
191
- if factor == 1:
192
- return Conv1d(
193
- in_channels=in_channels, out_channels=out_channels, kernel_size=3
194
- )
195
-
196
- if use_nearest:
197
- return nn.Sequential(
198
- nn.Upsample(scale_factor=factor, mode="nearest"),
199
- Conv1d(
200
- in_channels=in_channels,
201
- out_channels=out_channels,
202
- kernel_size=3
203
- ),
204
- )
205
- else:
206
- return ConvTranspose1d(
207
- in_channels=in_channels,
208
- out_channels=out_channels,
209
- kernel_size=factor * 2,
210
- stride=factor
211
- )
212
-
213
-
214
- class ConvBlock1d(nn.Module):
215
- def __init__(
216
- self,
217
- in_channels: int,
218
- out_channels: int,
219
- *,
220
- kernel_size: int = 3,
221
- stride: int = 1,
222
- dilation: int = 1,
223
- num_groups: int = 8,
224
- use_norm: bool = True,
225
- use_snake: bool = False
226
- ) -> None:
227
- super().__init__()
228
-
229
- self.groupnorm = (
230
- nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
231
- if use_norm
232
- else nn.Identity()
233
- )
234
-
235
- if use_snake:
236
- self.activation = Snake1d(in_channels)
237
- else:
238
- self.activation = nn.SiLU()
239
-
240
- self.project = Conv1d(
241
- in_channels=in_channels,
242
- out_channels=out_channels,
243
- kernel_size=kernel_size,
244
- stride=stride,
245
- dilation=dilation,
246
- )
247
-
248
- def forward(
249
- self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
250
- ) -> Tensor:
251
- x = self.groupnorm(x)
252
- if exists(scale_shift):
253
- scale, shift = scale_shift
254
- x = x * (scale + 1) + shift
255
- x = self.activation(x)
256
- return self.project(x, causal=causal)
257
-
258
-
259
- class MappingToScaleShift(nn.Module):
260
- def __init__(
261
- self,
262
- features: int,
263
- channels: int,
264
- ):
265
- super().__init__()
266
-
267
- self.to_scale_shift = nn.Sequential(
268
- nn.SiLU(),
269
- nn.Linear(in_features=features, out_features=channels * 2),
270
- )
271
-
272
- def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
273
- scale_shift = self.to_scale_shift(mapping)
274
- scale_shift = rearrange(scale_shift, "b c -> b c 1")
275
- scale, shift = scale_shift.chunk(2, dim=1)
276
- return scale, shift
277
-
278
-
279
- class ResnetBlock1d(nn.Module):
280
- def __init__(
281
- self,
282
- in_channels: int,
283
- out_channels: int,
284
- *,
285
- kernel_size: int = 3,
286
- stride: int = 1,
287
- dilation: int = 1,
288
- use_norm: bool = True,
289
- use_snake: bool = False,
290
- num_groups: int = 8,
291
- context_mapping_features: Optional[int] = None,
292
- ) -> None:
293
- super().__init__()
294
-
295
- self.use_mapping = exists(context_mapping_features)
296
-
297
- self.block1 = ConvBlock1d(
298
- in_channels=in_channels,
299
- out_channels=out_channels,
300
- kernel_size=kernel_size,
301
- stride=stride,
302
- dilation=dilation,
303
- use_norm=use_norm,
304
- num_groups=num_groups,
305
- use_snake=use_snake
306
- )
307
-
308
- if self.use_mapping:
309
- assert exists(context_mapping_features)
310
- self.to_scale_shift = MappingToScaleShift(
311
- features=context_mapping_features, channels=out_channels
312
- )
313
-
314
- self.block2 = ConvBlock1d(
315
- in_channels=out_channels,
316
- out_channels=out_channels,
317
- use_norm=use_norm,
318
- num_groups=num_groups,
319
- use_snake=use_snake
320
- )
321
-
322
- self.to_out = (
323
- Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
324
- if in_channels != out_channels
325
- else nn.Identity()
326
- )
327
-
328
- def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
329
- assert_message = "context mapping required if context_mapping_features > 0"
330
- assert not (self.use_mapping ^ exists(mapping)), assert_message
331
-
332
- h = self.block1(x, causal=causal)
333
-
334
- scale_shift = None
335
- if self.use_mapping:
336
- scale_shift = self.to_scale_shift(mapping)
337
-
338
- h = self.block2(h, scale_shift=scale_shift, causal=causal)
339
-
340
- return h + self.to_out(x)
341
-
342
-
343
- class Patcher(nn.Module):
344
- def __init__(
345
- self,
346
- in_channels: int,
347
- out_channels: int,
348
- patch_size: int,
349
- context_mapping_features: Optional[int] = None,
350
- use_snake: bool = False,
351
- ):
352
- super().__init__()
353
- assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354
- assert out_channels % patch_size == 0, assert_message
355
- self.patch_size = patch_size
356
-
357
- self.block = ResnetBlock1d(
358
- in_channels=in_channels,
359
- out_channels=out_channels // patch_size,
360
- num_groups=1,
361
- context_mapping_features=context_mapping_features,
362
- use_snake=use_snake
363
- )
364
-
365
- def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
366
- x = self.block(x, mapping, causal=causal)
367
- x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
368
- return x
369
-
370
-
371
- class Unpatcher(nn.Module):
372
- def __init__(
373
- self,
374
- in_channels: int,
375
- out_channels: int,
376
- patch_size: int,
377
- context_mapping_features: Optional[int] = None,
378
- use_snake: bool = False
379
- ):
380
- super().__init__()
381
- assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382
- assert in_channels % patch_size == 0, assert_message
383
- self.patch_size = patch_size
384
-
385
- self.block = ResnetBlock1d(
386
- in_channels=in_channels // patch_size,
387
- out_channels=out_channels,
388
- num_groups=1,
389
- context_mapping_features=context_mapping_features,
390
- use_snake=use_snake
391
- )
392
-
393
- def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
394
- x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
395
- x = self.block(x, mapping, causal=causal)
396
- return x
397
-
398
-
399
- """
400
- Attention Components
401
- """
402
- def FeedForward(features: int, multiplier: int) -> nn.Module:
403
- mid_features = features * multiplier
404
- return nn.Sequential(
405
- nn.Linear(in_features=features, out_features=mid_features),
406
- nn.GELU(),
407
- nn.Linear(in_features=mid_features, out_features=features),
408
- )
409
-
410
- def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
411
- b, ndim = sim.shape[0], mask.ndim
412
- if ndim == 3:
413
- mask = rearrange(mask, "b n m -> b 1 n m")
414
- if ndim == 2:
415
- mask = repeat(mask, "n m -> b 1 n m", b=b)
416
- max_neg_value = -torch.finfo(sim.dtype).max
417
- sim = sim.masked_fill(~mask, max_neg_value)
418
- return sim
419
-
420
- def causal_mask(q: Tensor, k: Tensor) -> Tensor:
421
- b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
422
- mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
423
- mask = repeat(mask, "n m -> b n m", b=b)
424
- return mask
425
-
426
- class AttentionBase(nn.Module):
427
- def __init__(
428
- self,
429
- features: int,
430
- *,
431
- head_features: int,
432
- num_heads: int,
433
- out_features: Optional[int] = None,
434
- ):
435
- super().__init__()
436
- self.scale = head_features**-0.5
437
- self.num_heads = num_heads
438
- mid_features = head_features * num_heads
439
- out_features = default(out_features, features)
440
-
441
- self.to_out = nn.Linear(
442
- in_features=mid_features, out_features=out_features
443
- )
444
-
445
- self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
446
-
447
- if not self.use_flash:
448
- return
449
-
450
- device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
451
-
452
- if device_properties.major == 8 and device_properties.minor == 0:
453
- # Use flash attention for A100 GPUs
454
- self.sdp_kernel_config = (True, False, False)
455
- else:
456
- # Don't use flash attention for other GPUs
457
- self.sdp_kernel_config = (False, True, True)
458
-
459
- def forward(
460
- self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
461
- ) -> Tensor:
462
- # Split heads
463
- q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
464
-
465
- if not self.use_flash:
466
- if is_causal and not mask:
467
- # Mask out future tokens for causal attention
468
- mask = causal_mask(q, k)
469
-
470
- # Compute similarity matrix and add eventual mask
471
- sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
472
- sim = add_mask(sim, mask) if exists(mask) else sim
473
-
474
- # Get attention matrix with softmax
475
- attn = sim.softmax(dim=-1, dtype=torch.float32)
476
-
477
- # Compute values
478
- out = einsum("... n m, ... m d -> ... n d", attn, v)
479
- else:
480
- with sdp_kernel(*self.sdp_kernel_config):
481
- out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
482
-
483
- out = rearrange(out, "b h n d -> b n (h d)")
484
- return self.to_out(out)
485
-
486
- class Attention(nn.Module):
487
- def __init__(
488
- self,
489
- features: int,
490
- *,
491
- head_features: int,
492
- num_heads: int,
493
- out_features: Optional[int] = None,
494
- context_features: Optional[int] = None,
495
- causal: bool = False,
496
- ):
497
- super().__init__()
498
- self.context_features = context_features
499
- self.causal = causal
500
- mid_features = head_features * num_heads
501
- context_features = default(context_features, features)
502
-
503
- self.norm = nn.LayerNorm(features)
504
- self.norm_context = nn.LayerNorm(context_features)
505
- self.to_q = nn.Linear(
506
- in_features=features, out_features=mid_features, bias=False
507
- )
508
- self.to_kv = nn.Linear(
509
- in_features=context_features, out_features=mid_features * 2, bias=False
510
- )
511
- self.attention = AttentionBase(
512
- features,
513
- num_heads=num_heads,
514
- head_features=head_features,
515
- out_features=out_features,
516
- )
517
-
518
- def forward(
519
- self,
520
- x: Tensor, # [b, n, c]
521
- context: Optional[Tensor] = None, # [b, m, d]
522
- context_mask: Optional[Tensor] = None, # [b, m], false is masked,
523
- causal: Optional[bool] = False,
524
- ) -> Tensor:
525
- assert_message = "You must provide a context when using context_features"
526
- assert not self.context_features or exists(context), assert_message
527
- # Use context if provided
528
- context = default(context, x)
529
- # Normalize then compute q from input and k,v from context
530
- x, context = self.norm(x), self.norm_context(context)
531
-
532
- q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
533
-
534
- if exists(context_mask):
535
- # Mask out cross-attention for padding tokens
536
- mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
537
- k, v = k * mask, v * mask
538
-
539
- # Compute and return attention
540
- return self.attention(q, k, v, is_causal=self.causal or causal)
541
-
542
-
543
- def FeedForward(features: int, multiplier: int) -> nn.Module:
544
- mid_features = features * multiplier
545
- return nn.Sequential(
546
- nn.Linear(in_features=features, out_features=mid_features),
547
- nn.GELU(),
548
- nn.Linear(in_features=mid_features, out_features=features),
549
- )
550
-
551
- """
552
- Transformer Blocks
553
- """
554
-
555
-
556
- class TransformerBlock(nn.Module):
557
- def __init__(
558
- self,
559
- features: int,
560
- num_heads: int,
561
- head_features: int,
562
- multiplier: int,
563
- context_features: Optional[int] = None,
564
- ):
565
- super().__init__()
566
-
567
- self.use_cross_attention = exists(context_features) and context_features > 0
568
-
569
- self.attention = Attention(
570
- features=features,
571
- num_heads=num_heads,
572
- head_features=head_features
573
- )
574
-
575
- if self.use_cross_attention:
576
- self.cross_attention = Attention(
577
- features=features,
578
- num_heads=num_heads,
579
- head_features=head_features,
580
- context_features=context_features
581
- )
582
-
583
- self.feed_forward = FeedForward(features=features, multiplier=multiplier)
584
-
585
- def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
586
- x = self.attention(x, causal=causal) + x
587
- if self.use_cross_attention:
588
- x = self.cross_attention(x, context=context, context_mask=context_mask) + x
589
- x = self.feed_forward(x) + x
590
- return x
591
-
592
-
593
- """
594
- Transformers
595
- """
596
-
597
-
598
- class Transformer1d(nn.Module):
599
- def __init__(
600
- self,
601
- num_layers: int,
602
- channels: int,
603
- num_heads: int,
604
- head_features: int,
605
- multiplier: int,
606
- context_features: Optional[int] = None,
607
- ):
608
- super().__init__()
609
-
610
- self.to_in = nn.Sequential(
611
- nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
612
- Conv1d(
613
- in_channels=channels,
614
- out_channels=channels,
615
- kernel_size=1,
616
- ),
617
- Rearrange("b c t -> b t c"),
618
- )
619
-
620
- self.blocks = nn.ModuleList(
621
- [
622
- TransformerBlock(
623
- features=channels,
624
- head_features=head_features,
625
- num_heads=num_heads,
626
- multiplier=multiplier,
627
- context_features=context_features,
628
- )
629
- for i in range(num_layers)
630
- ]
631
- )
632
-
633
- self.to_out = nn.Sequential(
634
- Rearrange("b t c -> b c t"),
635
- Conv1d(
636
- in_channels=channels,
637
- out_channels=channels,
638
- kernel_size=1,
639
- ),
640
- )
641
-
642
- def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
643
- x = self.to_in(x)
644
- for block in self.blocks:
645
- x = block(x, context=context, context_mask=context_mask, causal=causal)
646
- x = self.to_out(x)
647
- return x
648
-
649
-
650
- """
651
- Time Embeddings
652
- """
653
-
654
-
655
- class SinusoidalEmbedding(nn.Module):
656
- def __init__(self, dim: int):
657
- super().__init__()
658
- self.dim = dim
659
-
660
- def forward(self, x: Tensor) -> Tensor:
661
- device, half_dim = x.device, self.dim // 2
662
- emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
663
- emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
664
- emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
665
- return torch.cat((emb.sin(), emb.cos()), dim=-1)
666
-
667
-
668
- class LearnedPositionalEmbedding(nn.Module):
669
- """Used for continuous time"""
670
-
671
- def __init__(self, dim: int):
672
- super().__init__()
673
- assert (dim % 2) == 0
674
- half_dim = dim // 2
675
- self.weights = nn.Parameter(torch.randn(half_dim))
676
-
677
- def forward(self, x: Tensor) -> Tensor:
678
- x = rearrange(x, "b -> b 1")
679
- freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
680
- fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
681
- fouriered = torch.cat((x, fouriered), dim=-1)
682
- return fouriered
683
-
684
-
685
- def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
686
- return nn.Sequential(
687
- LearnedPositionalEmbedding(dim),
688
- nn.Linear(in_features=dim + 1, out_features=out_features),
689
- )
690
-
691
-
692
- """
693
- Encoder/Decoder Components
694
- """
695
-
696
-
697
- class DownsampleBlock1d(nn.Module):
698
- def __init__(
699
- self,
700
- in_channels: int,
701
- out_channels: int,
702
- *,
703
- factor: int,
704
- num_groups: int,
705
- num_layers: int,
706
- kernel_multiplier: int = 2,
707
- use_pre_downsample: bool = True,
708
- use_skip: bool = False,
709
- use_snake: bool = False,
710
- extract_channels: int = 0,
711
- context_channels: int = 0,
712
- num_transformer_blocks: int = 0,
713
- attention_heads: Optional[int] = None,
714
- attention_features: Optional[int] = None,
715
- attention_multiplier: Optional[int] = None,
716
- context_mapping_features: Optional[int] = None,
717
- context_embedding_features: Optional[int] = None,
718
- ):
719
- super().__init__()
720
- self.use_pre_downsample = use_pre_downsample
721
- self.use_skip = use_skip
722
- self.use_transformer = num_transformer_blocks > 0
723
- self.use_extract = extract_channels > 0
724
- self.use_context = context_channels > 0
725
-
726
- channels = out_channels if use_pre_downsample else in_channels
727
-
728
- self.downsample = Downsample1d(
729
- in_channels=in_channels,
730
- out_channels=out_channels,
731
- factor=factor,
732
- kernel_multiplier=kernel_multiplier,
733
- )
734
-
735
- self.blocks = nn.ModuleList(
736
- [
737
- ResnetBlock1d(
738
- in_channels=channels + context_channels if i == 0 else channels,
739
- out_channels=channels,
740
- num_groups=num_groups,
741
- context_mapping_features=context_mapping_features,
742
- use_snake=use_snake
743
- )
744
- for i in range(num_layers)
745
- ]
746
- )
747
-
748
- if self.use_transformer:
749
- assert (
750
- (exists(attention_heads) or exists(attention_features))
751
- and exists(attention_multiplier)
752
- )
753
-
754
- if attention_features is None and attention_heads is not None:
755
- attention_features = channels // attention_heads
756
-
757
- if attention_heads is None and attention_features is not None:
758
- attention_heads = channels // attention_features
759
-
760
- self.transformer = Transformer1d(
761
- num_layers=num_transformer_blocks,
762
- channels=channels,
763
- num_heads=attention_heads,
764
- head_features=attention_features,
765
- multiplier=attention_multiplier,
766
- context_features=context_embedding_features
767
- )
768
-
769
- if self.use_extract:
770
- num_extract_groups = min(num_groups, extract_channels)
771
- self.to_extracted = ResnetBlock1d(
772
- in_channels=out_channels,
773
- out_channels=extract_channels,
774
- num_groups=num_extract_groups,
775
- use_snake=use_snake
776
- )
777
-
778
- def forward(
779
- self,
780
- x: Tensor,
781
- *,
782
- mapping: Optional[Tensor] = None,
783
- channels: Optional[Tensor] = None,
784
- embedding: Optional[Tensor] = None,
785
- embedding_mask: Optional[Tensor] = None,
786
- causal: Optional[bool] = False
787
- ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
788
-
789
- if self.use_pre_downsample:
790
- x = self.downsample(x)
791
-
792
- if self.use_context and exists(channels):
793
- x = torch.cat([x, channels], dim=1)
794
-
795
- skips = []
796
- for block in self.blocks:
797
- x = block(x, mapping=mapping, causal=causal)
798
- skips += [x] if self.use_skip else []
799
-
800
- if self.use_transformer:
801
- x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
802
- skips += [x] if self.use_skip else []
803
-
804
- if not self.use_pre_downsample:
805
- x = self.downsample(x)
806
-
807
- if self.use_extract:
808
- extracted = self.to_extracted(x)
809
- return x, extracted
810
-
811
- return (x, skips) if self.use_skip else x
812
-
813
-
814
- class UpsampleBlock1d(nn.Module):
815
- def __init__(
816
- self,
817
- in_channels: int,
818
- out_channels: int,
819
- *,
820
- factor: int,
821
- num_layers: int,
822
- num_groups: int,
823
- use_nearest: bool = False,
824
- use_pre_upsample: bool = False,
825
- use_skip: bool = False,
826
- use_snake: bool = False,
827
- skip_channels: int = 0,
828
- use_skip_scale: bool = False,
829
- extract_channels: int = 0,
830
- num_transformer_blocks: int = 0,
831
- attention_heads: Optional[int] = None,
832
- attention_features: Optional[int] = None,
833
- attention_multiplier: Optional[int] = None,
834
- context_mapping_features: Optional[int] = None,
835
- context_embedding_features: Optional[int] = None,
836
- ):
837
- super().__init__()
838
-
839
- self.use_extract = extract_channels > 0
840
- self.use_pre_upsample = use_pre_upsample
841
- self.use_transformer = num_transformer_blocks > 0
842
- self.use_skip = use_skip
843
- self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
844
-
845
- channels = out_channels if use_pre_upsample else in_channels
846
-
847
- self.blocks = nn.ModuleList(
848
- [
849
- ResnetBlock1d(
850
- in_channels=channels + skip_channels,
851
- out_channels=channels,
852
- num_groups=num_groups,
853
- context_mapping_features=context_mapping_features,
854
- use_snake=use_snake
855
- )
856
- for _ in range(num_layers)
857
- ]
858
- )
859
-
860
- if self.use_transformer:
861
- assert (
862
- (exists(attention_heads) or exists(attention_features))
863
- and exists(attention_multiplier)
864
- )
865
-
866
- if attention_features is None and attention_heads is not None:
867
- attention_features = channels // attention_heads
868
-
869
- if attention_heads is None and attention_features is not None:
870
- attention_heads = channels // attention_features
871
-
872
- self.transformer = Transformer1d(
873
- num_layers=num_transformer_blocks,
874
- channels=channels,
875
- num_heads=attention_heads,
876
- head_features=attention_features,
877
- multiplier=attention_multiplier,
878
- context_features=context_embedding_features,
879
- )
880
-
881
- self.upsample = Upsample1d(
882
- in_channels=in_channels,
883
- out_channels=out_channels,
884
- factor=factor,
885
- use_nearest=use_nearest,
886
- )
887
-
888
- if self.use_extract:
889
- num_extract_groups = min(num_groups, extract_channels)
890
- self.to_extracted = ResnetBlock1d(
891
- in_channels=out_channels,
892
- out_channels=extract_channels,
893
- num_groups=num_extract_groups,
894
- use_snake=use_snake
895
- )
896
-
897
- def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
898
- return torch.cat([x, skip * self.skip_scale], dim=1)
899
-
900
- def forward(
901
- self,
902
- x: Tensor,
903
- *,
904
- skips: Optional[List[Tensor]] = None,
905
- mapping: Optional[Tensor] = None,
906
- embedding: Optional[Tensor] = None,
907
- embedding_mask: Optional[Tensor] = None,
908
- causal: Optional[bool] = False
909
- ) -> Union[Tuple[Tensor, Tensor], Tensor]:
910
-
911
- if self.use_pre_upsample:
912
- x = self.upsample(x)
913
-
914
- for block in self.blocks:
915
- x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
916
- x = block(x, mapping=mapping, causal=causal)
917
-
918
- if self.use_transformer:
919
- x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
920
-
921
- if not self.use_pre_upsample:
922
- x = self.upsample(x)
923
-
924
- if self.use_extract:
925
- extracted = self.to_extracted(x)
926
- return x, extracted
927
-
928
- return x
929
-
930
-
931
- class BottleneckBlock1d(nn.Module):
932
- def __init__(
933
- self,
934
- channels: int,
935
- *,
936
- num_groups: int,
937
- num_transformer_blocks: int = 0,
938
- attention_heads: Optional[int] = None,
939
- attention_features: Optional[int] = None,
940
- attention_multiplier: Optional[int] = None,
941
- context_mapping_features: Optional[int] = None,
942
- context_embedding_features: Optional[int] = None,
943
- use_snake: bool = False,
944
- ):
945
- super().__init__()
946
- self.use_transformer = num_transformer_blocks > 0
947
-
948
- self.pre_block = ResnetBlock1d(
949
- in_channels=channels,
950
- out_channels=channels,
951
- num_groups=num_groups,
952
- context_mapping_features=context_mapping_features,
953
- use_snake=use_snake
954
- )
955
-
956
- if self.use_transformer:
957
- assert (
958
- (exists(attention_heads) or exists(attention_features))
959
- and exists(attention_multiplier)
960
- )
961
-
962
- if attention_features is None and attention_heads is not None:
963
- attention_features = channels // attention_heads
964
-
965
- if attention_heads is None and attention_features is not None:
966
- attention_heads = channels // attention_features
967
-
968
- self.transformer = Transformer1d(
969
- num_layers=num_transformer_blocks,
970
- channels=channels,
971
- num_heads=attention_heads,
972
- head_features=attention_features,
973
- multiplier=attention_multiplier,
974
- context_features=context_embedding_features,
975
- )
976
-
977
- self.post_block = ResnetBlock1d(
978
- in_channels=channels,
979
- out_channels=channels,
980
- num_groups=num_groups,
981
- context_mapping_features=context_mapping_features,
982
- use_snake=use_snake
983
- )
984
-
985
- def forward(
986
- self,
987
- x: Tensor,
988
- *,
989
- mapping: Optional[Tensor] = None,
990
- embedding: Optional[Tensor] = None,
991
- embedding_mask: Optional[Tensor] = None,
992
- causal: Optional[bool] = False
993
- ) -> Tensor:
994
- x = self.pre_block(x, mapping=mapping, causal=causal)
995
- if self.use_transformer:
996
- x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
997
- x = self.post_block(x, mapping=mapping, causal=causal)
998
- return x
999
-
1000
-
1001
- """
1002
- UNet
1003
- """
1004
-
1005
-
1006
- class UNet1d(nn.Module):
1007
- def __init__(
1008
- self,
1009
- in_channels: int,
1010
- channels: int,
1011
- multipliers: Sequence[int],
1012
- factors: Sequence[int],
1013
- num_blocks: Sequence[int],
1014
- attentions: Sequence[int],
1015
- patch_size: int = 1,
1016
- resnet_groups: int = 8,
1017
- use_context_time: bool = True,
1018
- kernel_multiplier_downsample: int = 2,
1019
- use_nearest_upsample: bool = False,
1020
- use_skip_scale: bool = True,
1021
- use_snake: bool = False,
1022
- use_stft: bool = False,
1023
- use_stft_context: bool = False,
1024
- out_channels: Optional[int] = None,
1025
- context_features: Optional[int] = None,
1026
- context_features_multiplier: int = 4,
1027
- context_channels: Optional[Sequence[int]] = None,
1028
- context_embedding_features: Optional[int] = None,
1029
- **kwargs,
1030
- ):
1031
- super().__init__()
1032
- out_channels = default(out_channels, in_channels)
1033
- context_channels = list(default(context_channels, []))
1034
- num_layers = len(multipliers) - 1
1035
- use_context_features = exists(context_features)
1036
- use_context_channels = len(context_channels) > 0
1037
- context_mapping_features = None
1038
-
1039
- attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1040
-
1041
- self.num_layers = num_layers
1042
- self.use_context_time = use_context_time
1043
- self.use_context_features = use_context_features
1044
- self.use_context_channels = use_context_channels
1045
- self.use_stft = use_stft
1046
- self.use_stft_context = use_stft_context
1047
-
1048
- self.context_features = context_features
1049
- context_channels_pad_length = num_layers + 1 - len(context_channels)
1050
- context_channels = context_channels + [0] * context_channels_pad_length
1051
- self.context_channels = context_channels
1052
- self.context_embedding_features = context_embedding_features
1053
-
1054
- if use_context_channels:
1055
- has_context = [c > 0 for c in context_channels]
1056
- self.has_context = has_context
1057
- self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1058
-
1059
- assert (
1060
- len(factors) == num_layers
1061
- and len(attentions) >= num_layers
1062
- and len(num_blocks) == num_layers
1063
- )
1064
-
1065
- if use_context_time or use_context_features:
1066
- context_mapping_features = channels * context_features_multiplier
1067
-
1068
- self.to_mapping = nn.Sequential(
1069
- nn.Linear(context_mapping_features, context_mapping_features),
1070
- nn.GELU(),
1071
- nn.Linear(context_mapping_features, context_mapping_features),
1072
- nn.GELU(),
1073
- )
1074
-
1075
- if use_context_time:
1076
- assert exists(context_mapping_features)
1077
- self.to_time = nn.Sequential(
1078
- TimePositionalEmbedding(
1079
- dim=channels, out_features=context_mapping_features
1080
- ),
1081
- nn.GELU(),
1082
- )
1083
-
1084
- if use_context_features:
1085
- assert exists(context_features) and exists(context_mapping_features)
1086
- self.to_features = nn.Sequential(
1087
- nn.Linear(
1088
- in_features=context_features, out_features=context_mapping_features
1089
- ),
1090
- nn.GELU(),
1091
- )
1092
-
1093
- if use_stft:
1094
- stft_kwargs, kwargs = groupby("stft_", kwargs)
1095
- assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096
- stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097
- in_channels *= stft_channels
1098
- out_channels *= stft_channels
1099
- context_channels[0] *= stft_channels if use_stft_context else 1
1100
- assert exists(in_channels) and exists(out_channels)
1101
- self.stft = STFT(**stft_kwargs)
1102
-
1103
- assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
-
1105
- self.to_in = Patcher(
1106
- in_channels=in_channels + context_channels[0],
1107
- out_channels=channels * multipliers[0],
1108
- patch_size=patch_size,
1109
- context_mapping_features=context_mapping_features,
1110
- use_snake=use_snake
1111
- )
1112
-
1113
- self.downsamples = nn.ModuleList(
1114
- [
1115
- DownsampleBlock1d(
1116
- in_channels=channels * multipliers[i],
1117
- out_channels=channels * multipliers[i + 1],
1118
- context_mapping_features=context_mapping_features,
1119
- context_channels=context_channels[i + 1],
1120
- context_embedding_features=context_embedding_features,
1121
- num_layers=num_blocks[i],
1122
- factor=factors[i],
1123
- kernel_multiplier=kernel_multiplier_downsample,
1124
- num_groups=resnet_groups,
1125
- use_pre_downsample=True,
1126
- use_skip=True,
1127
- use_snake=use_snake,
1128
- num_transformer_blocks=attentions[i],
1129
- **attention_kwargs,
1130
- )
1131
- for i in range(num_layers)
1132
- ]
1133
- )
1134
-
1135
- self.bottleneck = BottleneckBlock1d(
1136
- channels=channels * multipliers[-1],
1137
- context_mapping_features=context_mapping_features,
1138
- context_embedding_features=context_embedding_features,
1139
- num_groups=resnet_groups,
1140
- num_transformer_blocks=attentions[-1],
1141
- use_snake=use_snake,
1142
- **attention_kwargs,
1143
- )
1144
-
1145
- self.upsamples = nn.ModuleList(
1146
- [
1147
- UpsampleBlock1d(
1148
- in_channels=channels * multipliers[i + 1],
1149
- out_channels=channels * multipliers[i],
1150
- context_mapping_features=context_mapping_features,
1151
- context_embedding_features=context_embedding_features,
1152
- num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1153
- factor=factors[i],
1154
- use_nearest=use_nearest_upsample,
1155
- num_groups=resnet_groups,
1156
- use_skip_scale=use_skip_scale,
1157
- use_pre_upsample=False,
1158
- use_skip=True,
1159
- use_snake=use_snake,
1160
- skip_channels=channels * multipliers[i + 1],
1161
- num_transformer_blocks=attentions[i],
1162
- **attention_kwargs,
1163
- )
1164
- for i in reversed(range(num_layers))
1165
- ]
1166
- )
1167
-
1168
- self.to_out = Unpatcher(
1169
- in_channels=channels * multipliers[0],
1170
- out_channels=out_channels,
1171
- patch_size=patch_size,
1172
- context_mapping_features=context_mapping_features,
1173
- use_snake=use_snake
1174
- )
1175
-
1176
- def get_channels(
1177
- self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1178
- ) -> Optional[Tensor]:
1179
- """Gets context channels at `layer` and checks that shape is correct"""
1180
- use_context_channels = self.use_context_channels and self.has_context[layer]
1181
- if not use_context_channels:
1182
- return None
1183
- assert exists(channels_list), "Missing context"
1184
- # Get channels index (skipping zero channel contexts)
1185
- channels_id = self.channels_ids[layer]
1186
- # Get channels
1187
- channels = channels_list[channels_id]
1188
- message = f"Missing context for layer {layer} at index {channels_id}"
1189
- assert exists(channels), message
1190
- # Check channels
1191
- num_channels = self.context_channels[layer]
1192
- message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193
- assert channels.shape[1] == num_channels, message
1194
- # STFT channels if requested
1195
- channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196
- return channels
1197
-
1198
- def get_mapping(
1199
- self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1200
- ) -> Optional[Tensor]:
1201
- """Combines context time features and features into mapping"""
1202
- items, mapping = [], None
1203
- # Compute time features
1204
- if self.use_context_time:
1205
- assert_message = "use_context_time=True but no time features provided"
1206
- assert exists(time), assert_message
1207
- items += [self.to_time(time)]
1208
- # Compute features
1209
- if self.use_context_features:
1210
- assert_message = "context_features exists but no features provided"
1211
- assert exists(features), assert_message
1212
- items += [self.to_features(features)]
1213
- # Compute joint mapping
1214
- if self.use_context_time or self.use_context_features:
1215
- mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1216
- mapping = self.to_mapping(mapping)
1217
- return mapping
1218
-
1219
- def forward(
1220
- self,
1221
- x: Tensor,
1222
- time: Optional[Tensor] = None,
1223
- *,
1224
- features: Optional[Tensor] = None,
1225
- channels_list: Optional[Sequence[Tensor]] = None,
1226
- embedding: Optional[Tensor] = None,
1227
- embedding_mask: Optional[Tensor] = None,
1228
- causal: Optional[bool] = False,
1229
- ) -> Tensor:
1230
- channels = self.get_channels(channels_list, layer=0)
1231
- # Apply stft if required
1232
- x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1233
- # Concat context channels at layer 0 if provided
1234
- x = torch.cat([x, channels], dim=1) if exists(channels) else x
1235
- # Compute mapping from time and features
1236
- mapping = self.get_mapping(time, features)
1237
- x = self.to_in(x, mapping, causal=causal)
1238
- skips_list = [x]
1239
-
1240
- for i, downsample in enumerate(self.downsamples):
1241
- channels = self.get_channels(channels_list, layer=i + 1)
1242
- x, skips = downsample(
1243
- x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
1244
- )
1245
- skips_list += [skips]
1246
-
1247
- x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1248
-
1249
- for i, upsample in enumerate(self.upsamples):
1250
- skips = skips_list.pop()
1251
- x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
1252
-
1253
- x += skips_list.pop()
1254
- x = self.to_out(x, mapping, causal=causal)
1255
- x = self.stft.decode1d(x) if self.use_stft else x
1256
-
1257
- return x
1258
-
1259
-
1260
- """ Conditioning Modules """
1261
-
1262
-
1263
- class FixedEmbedding(nn.Module):
1264
- def __init__(self, max_length: int, features: int):
1265
- super().__init__()
1266
- self.max_length = max_length
1267
- self.embedding = nn.Embedding(max_length, features)
1268
-
1269
- def forward(self, x: Tensor) -> Tensor:
1270
- batch_size, length, device = *x.shape[0:2], x.device
1271
- assert_message = "Input sequence length must be <= max_length"
1272
- assert length <= self.max_length, assert_message
1273
- position = torch.arange(length, device=device)
1274
- fixed_embedding = self.embedding(position)
1275
- fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1276
- return fixed_embedding
1277
-
1278
-
1279
- def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1280
- if proba == 1:
1281
- return torch.ones(shape, device=device, dtype=torch.bool)
1282
- elif proba == 0:
1283
- return torch.zeros(shape, device=device, dtype=torch.bool)
1284
- else:
1285
- return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1286
-
1287
-
1288
- class UNetCFG1d(UNet1d):
1289
-
1290
- """UNet1d with Classifier-Free Guidance"""
1291
-
1292
- def __init__(
1293
- self,
1294
- context_embedding_max_length: int,
1295
- context_embedding_features: int,
1296
- use_xattn_time: bool = False,
1297
- **kwargs,
1298
- ):
1299
- super().__init__(
1300
- context_embedding_features=context_embedding_features, **kwargs
1301
- )
1302
-
1303
- self.use_xattn_time = use_xattn_time
1304
-
1305
- if use_xattn_time:
1306
- assert exists(context_embedding_features)
1307
- self.to_time_embedding = nn.Sequential(
1308
- TimePositionalEmbedding(
1309
- dim=kwargs["channels"], out_features=context_embedding_features
1310
- ),
1311
- nn.GELU(),
1312
- )
1313
-
1314
- context_embedding_max_length += 1 # Add one for time embedding
1315
-
1316
- self.fixed_embedding = FixedEmbedding(
1317
- max_length=context_embedding_max_length, features=context_embedding_features
1318
- )
1319
-
1320
- def forward( # type: ignore
1321
- self,
1322
- x: Tensor,
1323
- time: Tensor,
1324
- *,
1325
- embedding: Tensor,
1326
- embedding_mask: Optional[Tensor] = None,
1327
- embedding_scale: float = 1.0,
1328
- embedding_mask_proba: float = 0.0,
1329
- batch_cfg: bool = False,
1330
- rescale_cfg: bool = False,
1331
- scale_phi: float = 0.4,
1332
- negative_embedding: Optional[Tensor] = None,
1333
- negative_embedding_mask: Optional[Tensor] = None,
1334
- **kwargs,
1335
- ) -> Tensor:
1336
- b, device = embedding.shape[0], embedding.device
1337
-
1338
- if self.use_xattn_time:
1339
- embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
1340
-
1341
- if embedding_mask is not None:
1342
- embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
1343
-
1344
- fixed_embedding = self.fixed_embedding(embedding)
1345
-
1346
- if embedding_mask_proba > 0.0:
1347
- # Randomly mask embedding
1348
- batch_mask = rand_bool(
1349
- shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1350
- )
1351
- embedding = torch.where(batch_mask, fixed_embedding, embedding)
1352
-
1353
- if embedding_scale != 1.0:
1354
- if batch_cfg:
1355
- batch_x = torch.cat([x, x], dim=0)
1356
- batch_time = torch.cat([time, time], dim=0)
1357
-
1358
- if negative_embedding is not None:
1359
- if negative_embedding_mask is not None:
1360
- negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
1361
-
1362
- negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
1363
-
1364
- batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1365
-
1366
- else:
1367
- batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1368
-
1369
- batch_mask = None
1370
- if embedding_mask is not None:
1371
- batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1372
-
1373
- batch_features = None
1374
- features = kwargs.pop("features", None)
1375
- if self.use_context_features:
1376
- batch_features = torch.cat([features, features], dim=0)
1377
-
1378
- batch_channels = None
1379
- channels_list = kwargs.pop("channels_list", None)
1380
- if self.use_context_channels:
1381
- batch_channels = []
1382
- for channels in channels_list:
1383
- batch_channels += [torch.cat([channels, channels], dim=0)]
1384
-
1385
- # Compute both normal and fixed embedding outputs
1386
- batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
1387
- out, out_masked = batch_out.chunk(2, dim=0)
1388
-
1389
- else:
1390
- # Compute both normal and fixed embedding outputs
1391
- out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1392
- out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
1393
-
1394
- out_cfg = out_masked + (out - out_masked) * embedding_scale
1395
-
1396
- if rescale_cfg:
1397
-
1398
- out_std = out.std(dim=1, keepdim=True)
1399
- out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1400
-
1401
- return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
1402
-
1403
- else:
1404
-
1405
- return out_cfg
1406
-
1407
- else:
1408
- return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
1409
-
1410
-
1411
- class UNetNCCA1d(UNet1d):
1412
-
1413
- """UNet1d with Noise Channel Conditioning Augmentation"""
1414
-
1415
- def __init__(self, context_features: int, **kwargs):
1416
- super().__init__(context_features=context_features, **kwargs)
1417
- self.embedder = NumberEmbedder(features=context_features)
1418
-
1419
- def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1420
- x = x if torch.is_tensor(x) else torch.tensor(x)
1421
- return x.expand(shape)
1422
-
1423
- def forward( # type: ignore
1424
- self,
1425
- x: Tensor,
1426
- time: Tensor,
1427
- *,
1428
- channels_list: Sequence[Tensor],
1429
- channels_augmentation: Union[
1430
- bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1431
- ] = False,
1432
- channels_scale: Union[
1433
- float, Sequence[float], Sequence[Sequence[float]], Tensor
1434
- ] = 0,
1435
- **kwargs,
1436
- ) -> Tensor:
1437
- b, n = x.shape[0], len(channels_list)
1438
- channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1439
- channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1440
-
1441
- # Augmentation (for each channel list item)
1442
- for i in range(n):
1443
- scale = channels_scale[:, i] * channels_augmentation[:, i]
1444
- scale = rearrange(scale, "b -> b 1 1")
1445
- item = channels_list[i]
1446
- channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1447
-
1448
- # Scale embedding (sum reduction if more than one channel list item)
1449
- channels_scale_emb = self.embedder(channels_scale)
1450
- channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1451
-
1452
- return super().forward(
1453
- x=x,
1454
- time=time,
1455
- channels_list=channels_list,
1456
- features=channels_scale_emb,
1457
- **kwargs,
1458
- )
1459
-
1460
-
1461
- class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1462
- def __init__(self, *args, **kwargs):
1463
- super().__init__(*args, **kwargs)
1464
-
1465
- def forward(self, *args, **kwargs): # type: ignore
1466
- return UNetCFG1d.forward(self, *args, **kwargs)
1467
-
1468
-
1469
- def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1470
- if type == "base":
1471
- return UNet1d(**kwargs)
1472
- elif type == "all":
1473
- return UNetAll1d(**kwargs)
1474
- elif type == "cfg":
1475
- return UNetCFG1d(**kwargs)
1476
- elif type == "ncca":
1477
- return UNetNCCA1d(**kwargs)
1478
- else:
1479
- raise ValueError(f"Unknown XUNet1d type: {type}")
1480
-
1481
- class NumberEmbedder(nn.Module):
1482
- def __init__(
1483
- self,
1484
- features: int,
1485
- dim: int = 256,
1486
- ):
1487
- super().__init__()
1488
- self.features = features
1489
- self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1490
-
1491
- def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1492
- if not torch.is_tensor(x):
1493
- device = next(self.embedding.parameters()).device
1494
- x = torch.tensor(x, device=device)
1495
- assert isinstance(x, Tensor)
1496
- shape = x.shape
1497
- x = rearrange(x, "... -> (...)")
1498
- embedding = self.embedding(x)
1499
- x = embedding.view(*shape, self.features)
1500
- return x # type: ignore
1501
-
1502
-
1503
- """
1504
- Audio Transforms
1505
- """
1506
-
1507
-
1508
- class STFT(nn.Module):
1509
- """Helper for torch stft and istft"""
1510
-
1511
- def __init__(
1512
- self,
1513
- num_fft: int = 1023,
1514
- hop_length: int = 256,
1515
- window_length: Optional[int] = None,
1516
- length: Optional[int] = None,
1517
- use_complex: bool = False,
1518
- ):
1519
- super().__init__()
1520
- self.num_fft = num_fft
1521
- self.hop_length = default(hop_length, floor(num_fft // 4))
1522
- self.window_length = default(window_length, num_fft)
1523
- self.length = length
1524
- self.register_buffer("window", torch.hann_window(self.window_length))
1525
- self.use_complex = use_complex
1526
-
1527
- def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1528
- b = wave.shape[0]
1529
- wave = rearrange(wave, "b c t -> (b c) t")
1530
-
1531
- stft = torch.stft(
1532
- wave,
1533
- n_fft=self.num_fft,
1534
- hop_length=self.hop_length,
1535
- win_length=self.window_length,
1536
- window=self.window, # type: ignore
1537
- return_complex=True,
1538
- normalized=True,
1539
- )
1540
-
1541
- if self.use_complex:
1542
- # Returns real and imaginary
1543
- stft_a, stft_b = stft.real, stft.imag
1544
- else:
1545
- # Returns magnitude and phase matrices
1546
- magnitude, phase = torch.abs(stft), torch.angle(stft)
1547
- stft_a, stft_b = magnitude, phase
1548
-
1549
- return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1550
-
1551
- def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1552
- b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1553
- length = closest_power_2(l * self.hop_length)
1554
-
1555
- stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1556
-
1557
- if self.use_complex:
1558
- real, imag = stft_a, stft_b
1559
- else:
1560
- magnitude, phase = stft_a, stft_b
1561
- real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1562
-
1563
- stft = torch.stack([real, imag], dim=-1)
1564
-
1565
- wave = torch.istft(
1566
- stft,
1567
- n_fft=self.num_fft,
1568
- hop_length=self.hop_length,
1569
- win_length=self.window_length,
1570
- window=self.window, # type: ignore
1571
- length=default(self.length, length),
1572
- normalized=True,
1573
- )
1574
-
1575
- return rearrange(wave, "(b c) t -> b c t", b=b)
1576
-
1577
- def encode1d(
1578
- self, wave: Tensor, stacked: bool = True
1579
- ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1580
- stft_a, stft_b = self.encode(wave)
1581
- stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1582
- return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1583
-
1584
- def decode1d(self, stft_pair: Tensor) -> Tensor:
1585
- f = self.num_fft // 2 + 1
1586
- stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1587
- stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1588
- return self.decode(stft_a, stft_b)
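Note: the padding helpers and the `Conv1d`/`ConvTranspose1d` wrappers above are what keep the 1-D convolutions length-preserving in both causal and centered modes. A minimal sketch of that behaviour, assuming the deleted module is importable under its package path (the exact import path is an assumption for illustration):

import torch
# Assumes the `Conv1d` wrapper from adp.py above; the import path is illustrative only.
from stable_audio_tools.models.adp import Conv1d

conv = Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
x = torch.randn(1, 1, 16)

y_causal = conv(x, causal=True)    # left-pads by kernel_size - stride = 2 samples
y_center = conv(x, causal=False)   # splits the same total padding across both sides

# Either mode preserves the sequence length for stride 1.
assert y_causal.shape == y_center.shape == (1, 1, 16)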
 
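Note: `UNetCFG1d.forward` combines the conditional and null-embedding outputs with standard classifier-free guidance, optionally rescaling the result toward the conditional prediction's statistics. A stand-alone sketch of that arithmetic on placeholder tensors (in the real forward pass `out` and `out_masked` are the two UNet outputs):

import torch

out = torch.randn(2, 2, 1024)         # conditional prediction (placeholder)
out_masked = torch.randn(2, 2, 1024)  # prediction with the fixed (null) embedding (placeholder)
embedding_scale, scale_phi = 3.0, 0.4

# Guidance: move away from the unconditional output along the conditional direction.
out_cfg = out_masked + (out - out_masked) * embedding_scale

# rescale_cfg=True pulls the guided output's per-channel std back toward the
# conditional prediction's std, then blends with the unrescaled result.
out_std = out.std(dim=1, keepdim=True)
out_cfg_std = out_cfg.std(dim=1, keepdim=True)
out_rescaled = scale_phi * (out_cfg * (out_std / out_cfg_std)) + (1 - scale_phi) * out_cfg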
 
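Note: `UNet1d` routes extra keyword arguments by prefix (`attention_*` to the transformer blocks, `stft_*` to the STFT helper) via the `groupby` utility defined near the top of the file. A small self-contained sketch of that convention (the splitting logic is re-implemented locally so the snippet runs without the package):

def groupby(prefix, d, keep_prefix=False):
    # Split `d` into (keys starting with `prefix`, everything else).
    with_prefix = {k: v for k, v in d.items() if k.startswith(prefix)}
    rest = {k: v for k, v in d.items() if not k.startswith(prefix)}
    if keep_prefix:
        return with_prefix, rest
    return {k[len(prefix):]: v for k, v in with_prefix.items()}, rest

kwargs = {"attention_heads": 8, "attention_multiplier": 2, "stft_num_fft": 1023, "channels": 128}
attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
stft_kwargs, kwargs = groupby("stft_", kwargs)
# attention_kwargs == {"attention_heads": 8, "attention_multiplier": 2}  (prefix kept)
# stft_kwargs      == {"num_fft": 1023}; kwargs == {"channels": 128}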
stable/build/lib/stable_audio_tools/models/autoencoders.py DELETED
@@ -1,794 +0,0 @@
1
- import torch
2
- import math
3
- import numpy as np
4
-
5
- from torch import nn
6
- from torch.nn import functional as F
7
- from torchaudio import transforms as T
8
- from alias_free_torch import Activation1d
9
- from dac.nn.layers import WNConv1d, WNConvTranspose1d
10
- from typing import Literal, Dict, Any
11
-
12
- from ..inference.sampling import sample
13
- from ..inference.utils import prepare_audio
14
- from .blocks import SnakeBeta
15
- from .bottleneck import Bottleneck, DiscreteBottleneck
16
- from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
17
- from .factory import create_pretransform_from_config, create_bottleneck_from_config
18
- from .pretransforms import Pretransform
19
-
20
- def checkpoint(function, *args, **kwargs):
21
- kwargs.setdefault("use_reentrant", False)
22
- return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
23
-
24
- def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
25
- if activation == "elu":
26
- act = nn.ELU()
27
- elif activation == "snake":
28
- act = SnakeBeta(channels)
29
- elif activation == "none":
30
- act = nn.Identity()
31
- else:
32
- raise ValueError(f"Unknown activation {activation}")
33
-
34
- if antialias:
35
- act = Activation1d(act)
36
-
37
- return act
38
-
39
- class ResidualUnit(nn.Module):
40
- def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
41
- super().__init__()
42
-
43
- self.dilation = dilation
44
-
45
- padding = (dilation * (7-1)) // 2
46
-
47
- self.layers = nn.Sequential(
48
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
49
- WNConv1d(in_channels=in_channels, out_channels=out_channels,
50
- kernel_size=7, dilation=dilation, padding=padding),
51
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
52
- WNConv1d(in_channels=out_channels, out_channels=out_channels,
53
- kernel_size=1)
54
- )
55
-
56
- def forward(self, x):
57
- res = x
58
-
59
- #x = checkpoint(self.layers, x)
60
- x = self.layers(x)
61
-
62
- return x + res
63
-
64
- class EncoderBlock(nn.Module):
65
- def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
66
- super().__init__()
67
-
68
- self.layers = nn.Sequential(
69
- ResidualUnit(in_channels=in_channels,
70
- out_channels=in_channels, dilation=1, use_snake=use_snake),
71
- ResidualUnit(in_channels=in_channels,
72
- out_channels=in_channels, dilation=3, use_snake=use_snake),
73
- ResidualUnit(in_channels=in_channels,
74
- out_channels=in_channels, dilation=9, use_snake=use_snake),
75
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
76
- WNConv1d(in_channels=in_channels, out_channels=out_channels,
77
- kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
78
- )
79
-
80
- def forward(self, x):
81
- return self.layers(x)
82
-
83
- class DecoderBlock(nn.Module):
84
- def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
85
- super().__init__()
86
-
87
- if use_nearest_upsample:
88
- upsample_layer = nn.Sequential(
89
- nn.Upsample(scale_factor=stride, mode="nearest"),
90
- WNConv1d(in_channels=in_channels,
91
- out_channels=out_channels,
92
- kernel_size=2*stride,
93
- stride=1,
94
- bias=False,
95
- padding='same')
96
- )
97
- else:
98
- upsample_layer = WNConvTranspose1d(in_channels=in_channels,
99
- out_channels=out_channels,
100
- kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
101
-
102
- self.layers = nn.Sequential(
103
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
104
- upsample_layer,
105
- ResidualUnit(in_channels=out_channels, out_channels=out_channels,
106
- dilation=1, use_snake=use_snake),
107
- ResidualUnit(in_channels=out_channels, out_channels=out_channels,
108
- dilation=3, use_snake=use_snake),
109
- ResidualUnit(in_channels=out_channels, out_channels=out_channels,
110
- dilation=9, use_snake=use_snake),
111
- )
112
-
113
- def forward(self, x):
114
- return self.layers(x)
115
-
116
- class OobleckEncoder(nn.Module):
117
- def __init__(self,
118
- in_channels=2,
119
- channels=128,
120
- latent_dim=32,
121
- c_mults = [1, 2, 4, 8],
122
- strides = [2, 4, 8, 8],
123
- use_snake=False,
124
- antialias_activation=False
125
- ):
126
- super().__init__()
127
-
128
- c_mults = [1] + c_mults
129
-
130
- self.depth = len(c_mults)
131
-
132
- layers = [
133
- WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
134
- ]
135
-
136
- for i in range(self.depth-1):
137
- layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
138
-
139
- layers += [
140
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
141
- WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
142
- ]
143
-
144
- self.layers = nn.Sequential(*layers)
145
-
146
- def forward(self, x):
147
- return self.layers(x)
148
-
149
-
150
- class OobleckDecoder(nn.Module):
151
- def __init__(self,
152
- out_channels=2,
153
- channels=128,
154
- latent_dim=32,
155
- c_mults = [1, 2, 4, 8],
156
- strides = [2, 4, 8, 8],
157
- use_snake=False,
158
- antialias_activation=False,
159
- use_nearest_upsample=False,
160
- final_tanh=True):
161
- super().__init__()
162
-
163
- c_mults = [1] + c_mults
164
-
165
- self.depth = len(c_mults)
166
-
167
- layers = [
168
- WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
169
- ]
170
-
171
- for i in range(self.depth-1, 0, -1):
172
- layers += [DecoderBlock(
173
- in_channels=c_mults[i]*channels,
174
- out_channels=c_mults[i-1]*channels,
175
- stride=strides[i-1],
176
- use_snake=use_snake,
177
- antialias_activation=antialias_activation,
178
- use_nearest_upsample=use_nearest_upsample
179
- )
180
- ]
181
-
182
- layers += [
183
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
184
- WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
185
- nn.Tanh() if final_tanh else nn.Identity()
186
- ]
187
-
188
- self.layers = nn.Sequential(*layers)
189
-
190
- def forward(self, x):
191
- return self.layers(x)
192
-
193
-
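The Oobleck encoder/decoder pair above is what the factory functions later in this file assemble. A minimal shape sketch follows; the hyperparameter values are illustrative and not taken from any shipped config.

```python
# Hedged sketch: with strides [2, 4, 8, 8] the encoder shrinks the time axis by
# 2*4*8*8 = 512, which is the value AudioAutoencoder exposes as downsampling_ratio.
import torch

enc = OobleckEncoder(in_channels=2, channels=128, latent_dim=128,
                     c_mults=[1, 2, 4, 8], strides=[2, 4, 8, 8])
dec = OobleckDecoder(out_channels=2, channels=128, latent_dim=128,
                     c_mults=[1, 2, 4, 8], strides=[2, 4, 8, 8])

x = torch.randn(1, 2, 131072)   # (batch, channels, samples); length chosen as a multiple of 512
z = enc(x)                      # (1, 128, 256) latent frames
y = dec(z)                      # (1, 2, 131072); untrained weights, so y is noise
```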
194
- class DACEncoderWrapper(nn.Module):
195
- def __init__(self, in_channels=1, **kwargs):
196
- super().__init__()
197
-
198
- from dac.model.dac import Encoder as DACEncoder
199
-
200
- latent_dim = kwargs.pop("latent_dim", None)
201
-
202
- encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
203
- self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
204
- self.latent_dim = latent_dim
205
-
206
- # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
207
- self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
208
-
209
- if in_channels != 1:
210
- self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
211
-
212
- def forward(self, x):
213
- x = self.encoder(x)
214
- x = self.proj_out(x)
215
- return x
216
-
217
- class DACDecoderWrapper(nn.Module):
218
- def __init__(self, latent_dim, out_channels=1, **kwargs):
219
- super().__init__()
220
-
221
- from dac.model.dac import Decoder as DACDecoder
222
-
223
- self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
224
-
225
- self.latent_dim = latent_dim
226
-
227
- def forward(self, x):
228
- return self.decoder(x)
229
-
230
- class AudioAutoencoder(nn.Module):
231
- def __init__(
232
- self,
233
- encoder,
234
- decoder,
235
- latent_dim,
236
- downsampling_ratio,
237
- sample_rate,
238
- io_channels=2,
239
- bottleneck: Bottleneck = None,
240
- pretransform: Pretransform = None,
241
- in_channels = None,
242
- out_channels = None,
243
- soft_clip = False
244
- ):
245
- super().__init__()
246
-
247
- self.downsampling_ratio = downsampling_ratio
248
- self.sample_rate = sample_rate
249
-
250
- self.latent_dim = latent_dim
251
- self.io_channels = io_channels
252
- self.in_channels = io_channels
253
- self.out_channels = io_channels
254
-
255
- self.min_length = self.downsampling_ratio
256
-
257
- if in_channels is not None:
258
- self.in_channels = in_channels
259
-
260
- if out_channels is not None:
261
- self.out_channels = out_channels
262
-
263
- self.bottleneck = bottleneck
264
-
265
- self.encoder = encoder
266
-
267
- self.decoder = decoder
268
-
269
- self.pretransform = pretransform
270
-
271
- self.soft_clip = soft_clip
272
-
273
- self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
274
-
275
- def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
276
-
277
- info = {}
278
-
279
- if self.pretransform is not None and not skip_pretransform:
280
- if self.pretransform.enable_grad:
281
- if iterate_batch:
282
- audios = []
283
- for i in range(audio.shape[0]):
284
- audios.append(self.pretransform.encode(audio[i:i+1]))
285
- audio = torch.cat(audios, dim=0)
286
- else:
287
- audio = self.pretransform.encode(audio)
288
- else:
289
- with torch.no_grad():
290
- if iterate_batch:
291
- audios = []
292
- for i in range(audio.shape[0]):
293
- audios.append(self.pretransform.encode(audio[i:i+1]))
294
- audio = torch.cat(audios, dim=0)
295
- else:
296
- audio = self.pretransform.encode(audio)
297
-
298
- if self.encoder is not None:
299
- if iterate_batch:
300
- latents = []
301
- for i in range(audio.shape[0]):
302
- latents.append(self.encoder(audio[i:i+1]))
303
- latents = torch.cat(latents, dim=0)
304
- else:
305
- latents = self.encoder(audio)
306
- else:
307
- latents = audio
308
-
309
- if self.bottleneck is not None:
310
- # TODO: Add iterate batch logic, needs to merge the info dicts
311
- latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
312
-
313
- info.update(bottleneck_info)
314
-
315
- if return_info:
316
- return latents, info
317
-
318
- return latents
319
-
320
- def decode(self, latents, iterate_batch=False, **kwargs):
321
-
322
- if self.bottleneck is not None:
323
- if iterate_batch:
324
- decoded = []
325
- for i in range(latents.shape[0]):
326
- decoded.append(self.bottleneck.decode(latents[i:i+1]))
327
- latents = torch.cat(decoded, dim=0)
328
- else:
329
- latents = self.bottleneck.decode(latents)
330
-
331
- if iterate_batch:
332
- decoded = []
333
- for i in range(latents.shape[0]):
334
- decoded.append(self.decoder(latents[i:i+1]))
335
- decoded = torch.cat(decoded, dim=0)
336
- else:
337
- decoded = self.decoder(latents, **kwargs)
338
-
339
- if self.pretransform is not None:
340
- if self.pretransform.enable_grad:
341
- if iterate_batch:
342
- decodeds = []
343
- for i in range(decoded.shape[0]):
344
- decodeds.append(self.pretransform.decode(decoded[i:i+1]))
345
- decoded = torch.cat(decodeds, dim=0)
346
- else:
347
- decoded = self.pretransform.decode(decoded)
348
- else:
349
- with torch.no_grad():
350
- if iterate_batch:
351
- decodeds = []
352
- for i in range(latents.shape[0]):
353
- decodeds.append(self.pretransform.decode(decoded[i:i+1]))
354
- decoded = torch.cat(decodeds, dim=0)
355
- else:
356
- decoded = self.pretransform.decode(decoded)
357
-
358
- if self.soft_clip:
359
- decoded = torch.tanh(decoded)
360
-
361
- return decoded
362
-
363
- def decode_tokens(self, tokens, **kwargs):
364
- '''
365
- Decode discrete tokens to audio
366
- Only works with discrete autoencoders
367
- '''
368
-
369
- assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
370
-
371
- latents = self.bottleneck.decode_tokens(tokens, **kwargs)
372
-
373
- return self.decode(latents, **kwargs)
374
-
375
-
376
- def preprocess_audio_for_encoder(self, audio, in_sr):
377
- '''
378
- Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
379
- If the model is mono, stereo audio will be converted to mono.
380
- Audio will be silence-padded to be a multiple of the model's downsampling ratio.
381
- Audio will be resampled to the model's sample rate.
382
- The output will have batch size 1 and be shape (1 x Channels x Length)
383
- '''
384
- return self.preprocess_audio_list_for_encoder([audio], [in_sr])
385
-
386
- def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
387
- '''
388
- Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
389
- The audio in that list can be of different lengths and channels.
390
- in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
391
- All audio will be resampled to the model's sample rate.
392
- Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
393
- If the model is mono, all audio will be converted to mono.
394
- The output will be a tensor of shape (Batch x Channels x Length)
395
- '''
396
- batch_size = len(audio_list)
397
- if isinstance(in_sr_list, int):
398
- in_sr_list = [in_sr_list]*batch_size
399
- assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
400
- new_audio = []
401
- max_length = 0
402
- # resample & find the max length
403
- for i in range(batch_size):
404
- audio = audio_list[i]
405
- in_sr = in_sr_list[i]
406
- if len(audio.shape) == 3 and audio.shape[0] == 1:
407
- # batchsize 1 was given by accident. Just squeeze it.
408
- audio = audio.squeeze(0)
409
- elif len(audio.shape) == 1:
410
- # Mono signal, channel dimension is missing, unsqueeze it in
411
- audio = audio.unsqueeze(0)
412
- assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
413
- # Resample audio
414
- if in_sr != self.sample_rate:
415
- resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
416
- audio = resample_tf(audio)
417
- new_audio.append(audio)
418
- if audio.shape[-1] > max_length:
419
- max_length = audio.shape[-1]
420
- # Pad every audio to the same length, multiple of model's downsampling ratio
421
- padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
422
- for i in range(batch_size):
423
- # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
424
- new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
425
- target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
426
- # convert to tensor
427
- return torch.stack(new_audio)
428
-
429
- def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
430
- '''
431
- Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
432
- If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
433
- Overlap and chunk_size params are both measured in number of latents (not audio samples)
434
- and therefore you likely could use the same values with decode_audio.
435
- An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
436
- Every autoencoder will have a different receptive field size, and thus ideal overlap.
437
- You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
438
- The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
439
- Smaller chunk_size uses less memory, but more compute.
440
- The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
441
- For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
442
- '''
443
- if not chunked:
444
- # default behavior. Encode the entire audio in parallel
445
- return self.encode(audio, **kwargs)
446
- else:
447
- # CHUNKED ENCODING
448
- # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
449
- samples_per_latent = self.downsampling_ratio
450
- total_size = audio.shape[2] # in samples
451
- batch_size = audio.shape[0]
452
- chunk_size *= samples_per_latent # converting metric in latents to samples
453
- overlap *= samples_per_latent # converting metric in latents to samples
454
- hop_size = chunk_size - overlap
455
- chunks = []
456
- for i in range(0, total_size - chunk_size + 1, hop_size):
457
- chunk = audio[:,:,i:i+chunk_size]
458
- chunks.append(chunk)
459
- if i+chunk_size != total_size:
460
- # Final chunk
461
- chunk = audio[:,:,-chunk_size:]
462
- chunks.append(chunk)
463
- chunks = torch.stack(chunks)
464
- num_chunks = chunks.shape[0]
465
- # Note: y_size might be a different value from the latent length used in diffusion training
466
- # because we can encode audio of varying lengths
467
- # However, the audio should've been padded to a multiple of samples_per_latent by now.
468
- y_size = total_size // samples_per_latent
469
- # Create an empty latent, we will populate it with chunks as we encode them
470
- y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
471
- for i in range(num_chunks):
472
- x_chunk = chunks[i,:]
473
- # encode the chunk
474
- y_chunk = self.encode(x_chunk)
475
- # figure out where to put the audio along the time domain
476
- if i == num_chunks-1:
477
- # final chunk always goes at the end
478
- t_end = y_size
479
- t_start = t_end - y_chunk.shape[2]
480
- else:
481
- t_start = i * hop_size // samples_per_latent
482
- t_end = t_start + chunk_size // samples_per_latent
483
- # remove the edges of the overlaps
484
- ol = overlap//samples_per_latent//2
485
- chunk_start = 0
486
- chunk_end = y_chunk.shape[2]
487
- if i > 0:
488
- # no overlap for the start of the first chunk
489
- t_start += ol
490
- chunk_start += ol
491
- if i < num_chunks-1:
492
- # no overlap for the end of the last chunk
493
- t_end -= ol
494
- chunk_end -= ol
495
- # paste the chunked audio into our y_final output audio
496
- y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
497
- return y_final
498
-
499
- def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
500
- '''
501
- Decode latents to audio.
502
- If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
503
- An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
504
- Every autoencoder will have a different receptive field size, and thus ideal overlap.
505
- You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
506
- The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
507
- Smaller chunk_size uses less memory, but more compute.
508
- The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
509
- For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
510
- '''
511
- if not chunked:
512
- # default behavior. Decode the entire latent in parallel
513
- return self.decode(latents, **kwargs)
514
- else:
515
- # chunked decoding
516
- hop_size = chunk_size - overlap
517
- total_size = latents.shape[2]
518
- batch_size = latents.shape[0]
519
- chunks = []
520
- for i in range(0, total_size - chunk_size + 1, hop_size):
521
- chunk = latents[:,:,i:i+chunk_size]
522
- chunks.append(chunk)
523
- if i+chunk_size != total_size:
524
- # Final chunk
525
- chunk = latents[:,:,-chunk_size:]
526
- chunks.append(chunk)
527
- chunks = torch.stack(chunks)
528
- num_chunks = chunks.shape[0]
529
- # samples_per_latent is just the downsampling ratio
530
- samples_per_latent = self.downsampling_ratio
531
- # Create an empty waveform, we will populate it with chunks as we decode them
532
- y_size = total_size * samples_per_latent
533
- y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
534
- for i in range(num_chunks):
535
- x_chunk = chunks[i,:]
536
- # decode the chunk
537
- y_chunk = self.decode(x_chunk)
538
- # figure out where to put the audio along the time domain
539
- if i == num_chunks-1:
540
- # final chunk always goes at the end
541
- t_end = y_size
542
- t_start = t_end - y_chunk.shape[2]
543
- else:
544
- t_start = i * hop_size * samples_per_latent
545
- t_end = t_start + chunk_size * samples_per_latent
546
- # remove the edges of the overlaps
547
- ol = (overlap//2) * samples_per_latent
548
- chunk_start = 0
549
- chunk_end = y_chunk.shape[2]
550
- if i > 0:
551
- # no overlap for the start of the first chunk
552
- t_start += ol
553
- chunk_start += ol
554
- if i < num_chunks-1:
555
- # no overlap for the end of the last chunk
556
- t_end -= ol
557
- chunk_end -= ol
558
- # paste the chunked audio into our y_final output audio
559
- y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
560
- return y_final
561
-
562
-
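The docstrings of encode_audio and decode_audio above describe the chunking and overlap-trimming scheme in prose; the sketch below shows the intended call pattern. It assumes an already-constructed AudioAutoencoder named `autoencoder` and an audio file path, both of which are illustrative.

```python
# Hedged usage sketch for the preprocessing + chunked encode/decode path above.
import torch
import torchaudio

audio, in_sr = torchaudio.load("clip.wav")                      # (channels, samples); path is illustrative
batch = autoencoder.preprocess_audio_for_encoder(audio, in_sr)  # resample, pad, fix channel count -> (1, C, L)

with torch.no_grad():
    # overlap and chunk_size are counted in latents; a larger overlap reduces seam artefacts.
    latents = autoencoder.encode_audio(batch, chunked=True, overlap=32, chunk_size=128)
    recon = autoencoder.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
```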
563
- class DiffusionAutoencoder(AudioAutoencoder):
564
- def __init__(
565
- self,
566
- diffusion: ConditionedDiffusionModel,
567
- diffusion_downsampling_ratio,
568
- *args,
569
- **kwargs
570
- ):
571
- super().__init__(*args, **kwargs)
572
-
573
- self.diffusion = diffusion
574
-
575
- self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
576
-
577
- if self.encoder is not None:
578
- # Shrink the initial encoder parameters to avoid saturated latents
579
- with torch.no_grad():
580
- for param in self.encoder.parameters():
581
- param *= 0.5
582
-
583
- def decode(self, latents, steps=100):
584
-
585
- upsampled_length = latents.shape[2] * self.downsampling_ratio
586
-
587
- if self.bottleneck is not None:
588
- latents = self.bottleneck.decode(latents)
589
-
590
- if self.decoder is not None:
591
- latents = self.decoder(latents)
592
-
593
- # Upsample latents to match diffusion length
594
- if latents.shape[2] != upsampled_length:
595
- latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
596
-
597
- noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
598
- decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
599
-
600
- if self.pretransform is not None:
601
- if self.pretransform.enable_grad:
602
- decoded = self.pretransform.decode(decoded)
603
- else:
604
- with torch.no_grad():
605
- decoded = self.pretransform.decode(decoded)
606
-
607
- return decoded
608
-
609
- # AE factories
610
-
611
- def create_encoder_from_config(encoder_config: Dict[str, Any]):
612
- encoder_type = encoder_config.get("type", None)
613
- assert encoder_type is not None, "Encoder type must be specified"
614
-
615
- if encoder_type == "oobleck":
616
- encoder = OobleckEncoder(
617
- **encoder_config["config"]
618
- )
619
-
620
- elif encoder_type == "seanet":
621
- from encodec.modules import SEANetEncoder
622
- seanet_encoder_config = encoder_config["config"]
623
-
624
- #SEANet encoder expects strides in reverse order
625
- seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
626
- encoder = SEANetEncoder(
627
- **seanet_encoder_config
628
- )
629
- elif encoder_type == "dac":
630
- dac_config = encoder_config["config"]
631
-
632
- encoder = DACEncoderWrapper(**dac_config)
633
- elif encoder_type == "local_attn":
634
- from .local_attention import TransformerEncoder1D
635
-
636
- local_attn_config = encoder_config["config"]
637
-
638
- encoder = TransformerEncoder1D(
639
- **local_attn_config
640
- )
641
- else:
642
- raise ValueError(f"Unknown encoder type {encoder_type}")
643
-
644
- requires_grad = encoder_config.get("requires_grad", True)
645
- if not requires_grad:
646
- for param in encoder.parameters():
647
- param.requires_grad = False
648
-
649
- return encoder
650
-
651
- def create_decoder_from_config(decoder_config: Dict[str, Any]):
652
- decoder_type = decoder_config.get("type", None)
653
- assert decoder_type is not None, "Decoder type must be specified"
654
-
655
- if decoder_type == "oobleck":
656
- decoder = OobleckDecoder(
657
- **decoder_config["config"]
658
- )
659
- elif decoder_type == "seanet":
660
- from encodec.modules import SEANetDecoder
661
-
662
- decoder = SEANetDecoder(
663
- **decoder_config["config"]
664
- )
665
- elif decoder_type == "dac":
666
- dac_config = decoder_config["config"]
667
-
668
- decoder = DACDecoderWrapper(**dac_config)
669
- elif decoder_type == "local_attn":
670
- from .local_attention import TransformerDecoder1D
671
-
672
- local_attn_config = decoder_config["config"]
673
-
674
- decoder = TransformerDecoder1D(
675
- **local_attn_config
676
- )
677
- else:
678
- raise ValueError(f"Unknown decoder type {decoder_type}")
679
-
680
- requires_grad = decoder_config.get("requires_grad", True)
681
- if not requires_grad:
682
- for param in decoder.parameters():
683
- param.requires_grad = False
684
-
685
- return decoder
686
-
687
- def create_autoencoder_from_config(config: Dict[str, Any]):
688
-
689
- ae_config = config["model"]
690
-
691
- encoder = create_encoder_from_config(ae_config["encoder"])
692
- decoder = create_decoder_from_config(ae_config["decoder"])
693
-
694
- bottleneck = ae_config.get("bottleneck", None)
695
-
696
- latent_dim = ae_config.get("latent_dim", None)
697
- assert latent_dim is not None, "latent_dim must be specified in model config"
698
- downsampling_ratio = ae_config.get("downsampling_ratio", None)
699
- assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700
- io_channels = ae_config.get("io_channels", None)
701
- assert io_channels is not None, "io_channels must be specified in model config"
702
- sample_rate = config.get("sample_rate", None)
703
- assert sample_rate is not None, "sample_rate must be specified in model config"
704
-
705
- in_channels = ae_config.get("in_channels", None)
706
- out_channels = ae_config.get("out_channels", None)
707
-
708
- pretransform = ae_config.get("pretransform", None)
709
-
710
- if pretransform is not None:
711
- pretransform = create_pretransform_from_config(pretransform, sample_rate)
712
-
713
- if bottleneck is not None:
714
- bottleneck = create_bottleneck_from_config(bottleneck)
715
-
716
- soft_clip = ae_config["decoder"].get("soft_clip", False)
717
-
718
- return AudioAutoencoder(
719
- encoder,
720
- decoder,
721
- io_channels=io_channels,
722
- latent_dim=latent_dim,
723
- downsampling_ratio=downsampling_ratio,
724
- sample_rate=sample_rate,
725
- bottleneck=bottleneck,
726
- pretransform=pretransform,
727
- in_channels=in_channels,
728
- out_channels=out_channels,
729
- soft_clip=soft_clip
730
- )
731
-
732
- def create_diffAE_from_config(config: Dict[str, Any]):
733
-
734
- diffae_config = config["model"]
735
-
736
- if "encoder" in diffae_config:
737
- encoder = create_encoder_from_config(diffae_config["encoder"])
738
- else:
739
- encoder = None
740
-
741
- if "decoder" in diffae_config:
742
- decoder = create_decoder_from_config(diffae_config["decoder"])
743
- else:
744
- decoder = None
745
-
746
- diffusion_model_type = diffae_config["diffusion"]["type"]
747
-
748
- if diffusion_model_type == "DAU1d":
749
- diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
750
- elif diffusion_model_type == "adp_1d":
751
- diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
752
- elif diffusion_model_type == "dit":
753
- diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
754
-
755
- latent_dim = diffae_config.get("latent_dim", None)
756
- assert latent_dim is not None, "latent_dim must be specified in model config"
757
- downsampling_ratio = diffae_config.get("downsampling_ratio", None)
758
- assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759
- io_channels = diffae_config.get("io_channels", None)
760
- assert io_channels is not None, "io_channels must be specified in model config"
761
- sample_rate = config.get("sample_rate", None)
762
- assert sample_rate is not None, "sample_rate must be specified in model config"
763
-
764
- bottleneck = diffae_config.get("bottleneck", None)
765
-
766
- pretransform = diffae_config.get("pretransform", None)
767
-
768
- if pretransform is not None:
769
- pretransform = create_pretransform_from_config(pretransform, sample_rate)
770
-
771
- if bottleneck is not None:
772
- bottleneck = create_bottleneck_from_config(bottleneck)
773
-
774
- diffusion_downsampling_ratio = None
775
-
776
- if diffusion_model_type == "DAU1d":
777
- diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
778
- elif diffusion_model_type == "adp_1d":
779
- diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
780
- elif diffusion_model_type == "dit":
781
- diffusion_downsampling_ratio = 1
782
-
783
- return DiffusionAutoencoder(
784
- encoder=encoder,
785
- decoder=decoder,
786
- diffusion=diffusion,
787
- io_channels=io_channels,
788
- sample_rate=sample_rate,
789
- latent_dim=latent_dim,
790
- downsampling_ratio=downsampling_ratio,
791
- diffusion_downsampling_ratio=diffusion_downsampling_ratio,
792
- bottleneck=bottleneck,
793
- pretransform=pretransform
794
- )
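For reference, a hedged sketch of the kind of dictionary create_autoencoder_from_config consumes, reconstructed only from the keys the function reads above; the numeric values are placeholders, and the "vae" bottleneck type is an assumption about what create_bottleneck_from_config (in factory.py, not shown here) accepts.

```python
# Illustrative config skeleton; key names follow the lookups in create_autoencoder_from_config.
config = {
    "sample_rate": 44100,
    "model": {
        "encoder": {"type": "oobleck", "config": {
            "in_channels": 2, "channels": 128, "latent_dim": 128,
            "c_mults": [1, 2, 4, 8], "strides": [2, 4, 8, 8], "use_snake": True}},
        "decoder": {"type": "oobleck", "config": {
            "out_channels": 2, "channels": 128, "latent_dim": 64,
            "c_mults": [1, 2, 4, 8], "strides": [2, 4, 8, 8], "use_snake": True, "final_tanh": False}},
        "bottleneck": {"type": "vae"},   # assumed type name; 128 encoder channels split into mean/scale of 64
        "latent_dim": 64,
        "downsampling_ratio": 512,       # product of the encoder strides
        "io_channels": 2,
    },
}

autoencoder = create_autoencoder_from_config(config)
```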
stable/build/lib/stable_audio_tools/models/blocks.py DELETED
@@ -1,339 +0,0 @@
1
- from functools import reduce
2
- import math
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from torch.backends.cuda import sdp_kernel
9
- from packaging import version
10
-
11
- from dac.nn.layers import Snake1d
12
-
13
- class ResidualBlock(nn.Module):
14
- def __init__(self, main, skip=None):
15
- super().__init__()
16
- self.main = nn.Sequential(*main)
17
- self.skip = skip if skip else nn.Identity()
18
-
19
- def forward(self, input):
20
- return self.main(input) + self.skip(input)
21
-
22
- class ResConvBlock(ResidualBlock):
23
- def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
- skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
- super().__init__([
26
- nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
- nn.GroupNorm(1, c_mid),
28
- Snake1d(c_mid) if use_snake else nn.GELU(),
29
- nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
- nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
- (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
- ], skip)
33
-
34
- class SelfAttention1d(nn.Module):
35
- def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
- super().__init__()
37
- assert c_in % n_head == 0
38
- self.norm = nn.GroupNorm(1, c_in)
39
- self.n_head = n_head
40
- self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
- self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
- self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
-
44
- self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
45
-
46
- if not self.use_flash:
47
- return
48
-
49
- device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
50
-
51
- if device_properties.major == 8 and device_properties.minor == 0:
52
- # Use flash attention for A100 GPUs
53
- self.sdp_kernel_config = (True, False, False)
54
- else:
55
- # Don't use flash attention for other GPUs
56
- self.sdp_kernel_config = (False, True, True)
57
-
58
- def forward(self, input):
59
- n, c, s = input.shape
60
- qkv = self.qkv_proj(self.norm(input))
61
- qkv = qkv.view(
62
- [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
63
- q, k, v = qkv.chunk(3, dim=1)
64
- scale = k.shape[3]**-0.25
65
-
66
- if self.use_flash:
67
- with sdp_kernel(*self.sdp_kernel_config):
68
- y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
69
- else:
70
- att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
71
- y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
72
-
73
-
74
- return input + self.dropout(self.out_proj(y))
75
-
76
- class SkipBlock(nn.Module):
77
- def __init__(self, *main):
78
- super().__init__()
79
- self.main = nn.Sequential(*main)
80
-
81
- def forward(self, input):
82
- return torch.cat([self.main(input), input], dim=1)
83
-
84
- class FourierFeatures(nn.Module):
85
- def __init__(self, in_features, out_features, std=1.):
86
- super().__init__()
87
- assert out_features % 2 == 0
88
- self.weight = nn.Parameter(torch.randn(
89
- [out_features // 2, in_features]) * std)
90
-
91
- def forward(self, input):
92
- f = 2 * math.pi * input @ self.weight.T
93
- return torch.cat([f.cos(), f.sin()], dim=-1)
94
-
95
- def expand_to_planes(input, shape):
96
- return input[..., None].repeat([1, 1, shape[2]])
97
-
98
- _kernels = {
99
- 'linear':
100
- [1 / 8, 3 / 8, 3 / 8, 1 / 8],
101
- 'cubic':
102
- [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
103
- 0.43359375, 0.11328125, -0.03515625, -0.01171875],
104
- 'lanczos3':
105
- [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
106
- -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
107
- 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
108
- -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
109
- }
110
-
111
- class Downsample1d(nn.Module):
112
- def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
113
- super().__init__()
114
- self.pad_mode = pad_mode
115
- kernel_1d = torch.tensor(_kernels[kernel])
116
- self.pad = kernel_1d.shape[0] // 2 - 1
117
- self.register_buffer('kernel', kernel_1d)
118
- self.channels_last = channels_last
119
-
120
- def forward(self, x):
121
- if self.channels_last:
122
- x = x.permute(0, 2, 1)
123
- x = F.pad(x, (self.pad,) * 2, self.pad_mode)
124
- weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
125
- indices = torch.arange(x.shape[1], device=x.device)
126
- weight[indices, indices] = self.kernel.to(weight)
127
- x = F.conv1d(x, weight, stride=2)
128
- if self.channels_last:
129
- x = x.permute(0, 2, 1)
130
- return x
131
-
132
-
133
- class Upsample1d(nn.Module):
134
- def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
135
- super().__init__()
136
- self.pad_mode = pad_mode
137
- kernel_1d = torch.tensor(_kernels[kernel]) * 2
138
- self.pad = kernel_1d.shape[0] // 2 - 1
139
- self.register_buffer('kernel', kernel_1d)
140
- self.channels_last = channels_last
141
-
142
- def forward(self, x):
143
- if self.channels_last:
144
- x = x.permute(0, 2, 1)
145
- x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
146
- weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
147
- indices = torch.arange(x.shape[1], device=x.device)
148
- weight[indices, indices] = self.kernel.to(weight)
149
- x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
150
- if self.channels_last:
151
- x = x.permute(0, 2, 1)
152
- return x
153
-
154
- def Downsample1d_2(
155
- in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
156
- ) -> nn.Module:
157
- assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
158
-
159
- return nn.Conv1d(
160
- in_channels=in_channels,
161
- out_channels=out_channels,
162
- kernel_size=factor * kernel_multiplier + 1,
163
- stride=factor,
164
- padding=factor * (kernel_multiplier // 2),
165
- )
166
-
167
-
168
- def Upsample1d_2(
169
- in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
170
- ) -> nn.Module:
171
-
172
- if factor == 1:
173
- return nn.Conv1d(
174
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
175
- )
176
-
177
- if use_nearest:
178
- return nn.Sequential(
179
- nn.Upsample(scale_factor=factor, mode="nearest"),
180
- nn.Conv1d(
181
- in_channels=in_channels,
182
- out_channels=out_channels,
183
- kernel_size=3,
184
- padding=1,
185
- ),
186
- )
187
- else:
188
- return nn.ConvTranspose1d(
189
- in_channels=in_channels,
190
- out_channels=out_channels,
191
- kernel_size=factor * 2,
192
- stride=factor,
193
- padding=factor // 2 + factor % 2,
194
- output_padding=factor % 2,
195
- )
196
-
197
- def zero_init(layer):
198
- nn.init.zeros_(layer.weight)
199
- if layer.bias is not None:
200
- nn.init.zeros_(layer.bias)
201
- return layer
202
-
203
- def rms_norm(x, scale, eps):
204
- dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
205
- mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
206
- scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
207
- return x * scale.to(x.dtype)
208
-
209
- #rms_norm = torch.compile(rms_norm)
210
-
211
- class AdaRMSNorm(nn.Module):
212
- def __init__(self, features, cond_features, eps=1e-6):
213
- super().__init__()
214
- self.eps = eps
215
- self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
216
-
217
- def extra_repr(self):
218
- return f"eps={self.eps},"
219
-
220
- def forward(self, x, cond):
221
- return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
222
-
223
- def normalize(x, eps=1e-4):
224
- dim = list(range(1, x.ndim))
225
- n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
226
- alpha = np.sqrt(n.numel() / x.numel())
227
- return x / torch.add(eps, n, alpha=alpha)
228
-
229
- class ForcedWNConv1d(nn.Module):
230
- def __init__(self, in_channels, out_channels, kernel_size=1):
231
- super().__init__()
232
- self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
233
-
234
- def forward(self, x):
235
- if self.training:
236
- with torch.no_grad():
237
- self.weight.copy_(normalize(self.weight))
238
-
239
- fan_in = self.weight[0].numel()
240
-
241
- w = normalize(self.weight) / math.sqrt(fan_in)
242
-
243
- return F.conv1d(x, w, padding='same')
244
-
245
- # Kernels
246
-
247
- use_compile = True
248
-
249
- def compile(function, *args, **kwargs):
250
- if not use_compile:
251
- return function
252
- try:
253
- return torch.compile(function, *args, **kwargs)
254
- except RuntimeError:
255
- return function
256
-
257
-
258
- @compile
259
- def linear_geglu(x, weight, bias=None):
260
- x = x @ weight.mT
261
- if bias is not None:
262
- x = x + bias
263
- x, gate = x.chunk(2, dim=-1)
264
- return x * F.gelu(gate)
265
-
266
-
267
- @compile
268
- def rms_norm(x, scale, eps):
269
- dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
270
- mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
271
- scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
272
- return x * scale.to(x.dtype)
273
-
274
- # Layers
275
-
276
- class LinearGEGLU(nn.Linear):
277
- def __init__(self, in_features, out_features, bias=True):
278
- super().__init__(in_features, out_features * 2, bias=bias)
279
- self.out_features = out_features
280
-
281
- def forward(self, x):
282
- return linear_geglu(x, self.weight, self.bias)
283
-
284
-
285
- class RMSNorm(nn.Module):
286
- def __init__(self, shape, fix_scale = False, eps=1e-6):
287
- super().__init__()
288
- self.eps = eps
289
-
290
- if fix_scale:
291
- self.register_buffer("scale", torch.ones(shape))
292
- else:
293
- self.scale = nn.Parameter(torch.ones(shape))
294
-
295
- def extra_repr(self):
296
- return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
297
-
298
- def forward(self, x):
299
- return rms_norm(x, self.scale, self.eps)
300
-
301
- def snake_beta(x, alpha, beta):
302
- return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
303
-
304
- # try:
305
- # snake_beta = torch.compile(snake_beta)
306
- # except RuntimeError:
307
- # pass
308
-
309
- # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
310
- # License available in LICENSES/LICENSE_NVIDIA.txt
311
- class SnakeBeta(nn.Module):
312
-
313
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
314
- super(SnakeBeta, self).__init__()
315
- self.in_features = in_features
316
-
317
- # initialize alpha
318
- self.alpha_logscale = alpha_logscale
319
- if self.alpha_logscale: # log scale alphas initialized to zeros
320
- self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
321
- self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
322
- else: # linear scale alphas initialized to ones
323
- self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
324
- self.beta = nn.Parameter(torch.ones(in_features) * alpha)
325
-
326
- self.alpha.requires_grad = alpha_trainable
327
- self.beta.requires_grad = alpha_trainable
328
-
329
- self.no_div_by_zero = 0.000000001
330
-
331
- def forward(self, x):
332
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
333
- beta = self.beta.unsqueeze(0).unsqueeze(-1)
334
- if self.alpha_logscale:
335
- alpha = torch.exp(alpha)
336
- beta = torch.exp(beta)
337
- x = snake_beta(x, alpha, beta)
338
-
339
- return x
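A small sketch of the SnakeBeta activation defined above (x + sin²(αx)/β with per-channel α and β, stored in log scale by default); the shapes are arbitrary examples.

```python
# Hedged sketch: SnakeBeta acts pointwise on (batch, channels, time) tensors,
# with one learnable alpha/beta pair per channel.
import torch

act = SnakeBeta(in_features=64)
x = torch.randn(2, 64, 1024)
y = act(x)
assert y.shape == x.shape
```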
stable/build/lib/stable_audio_tools/models/bottleneck.py DELETED
@@ -1,326 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
-
5
- from einops import rearrange
6
- from vector_quantize_pytorch import ResidualVQ, FSQ
7
- from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
8
-
9
- class Bottleneck(nn.Module):
10
- def __init__(self, is_discrete: bool = False):
11
- super().__init__()
12
-
13
- self.is_discrete = is_discrete
14
-
15
- def encode(self, x, return_info=False, **kwargs):
16
- raise NotImplementedError
17
-
18
- def decode(self, x):
19
- raise NotImplementedError
20
-
21
- class DiscreteBottleneck(Bottleneck):
22
- def __init__(self, num_quantizers, codebook_size, tokens_id):
23
- super().__init__(is_discrete=True)
24
-
25
- self.num_quantizers = num_quantizers
26
- self.codebook_size = codebook_size
27
- self.tokens_id = tokens_id
28
-
29
- def decode_tokens(self, codes, **kwargs):
30
- raise NotImplementedError
31
-
32
- class TanhBottleneck(Bottleneck):
33
- def __init__(self):
34
- super().__init__(is_discrete=False)
35
- self.tanh = nn.Tanh()
36
-
37
- def encode(self, x, return_info=False):
38
- info = {}
39
-
40
- x = torch.tanh(x)
41
-
42
- if return_info:
43
- return x, info
44
- else:
45
- return x
46
-
47
- def decode(self, x):
48
- return x
49
-
50
- def vae_sample(mean, scale):
51
- stdev = nn.functional.softplus(scale) + 1e-4
52
- var = stdev * stdev
53
- logvar = torch.log(var)
54
- latents = torch.randn_like(mean) * stdev + mean
55
-
56
- kl = (mean * mean + var - logvar - 1).sum(1).mean()
57
-
58
- return latents, kl
59
-
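vae_sample above is the reparameterization step shared by the VAE-style bottlenecks: softplus turns the raw scale into a positive stdev, a latent is drawn as mean + stdev · ε, and the returned scalar is a diagonal-Gaussian KL term (without the usual 1/2 factor). A toy sketch with arbitrary shapes:

```python
# Hedged sketch: VAEBottleneck below feeds an encoder output with 2 * latent_dim
# channels into vae_sample after chunking it into mean and raw scale along dim=1.
import torch

encoder_out = torch.randn(4, 2 * 64, 215)   # (batch, 2 * latent_dim, latent_steps)
mean, scale = encoder_out.chunk(2, dim=1)
latents, kl = vae_sample(mean, scale)       # latents: (4, 64, 215), kl: scalar tensor
```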
60
- class VAEBottleneck(Bottleneck):
61
- def __init__(self):
62
- super().__init__(is_discrete=False)
63
-
64
- def encode(self, x, return_info=False, **kwargs):
65
- info = {}
66
-
67
- mean, scale = x.chunk(2, dim=1)
68
-
69
- x, kl = vae_sample(mean, scale)
70
-
71
- info["kl"] = kl
72
-
73
- if return_info:
74
- return x, info
75
- else:
76
- return x
77
-
78
- def decode(self, x):
79
- return x
80
-
81
- def compute_mean_kernel(x, y):
82
- kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
83
- return torch.exp(-kernel_input).mean()
84
-
85
- def compute_mmd(latents):
86
- latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
87
- noise = torch.randn_like(latents_reshaped)
88
-
89
- latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
90
- noise_kernel = compute_mean_kernel(noise, noise)
91
- latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
92
-
93
- mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
94
- return mmd.mean()
95
-
96
- class WassersteinBottleneck(Bottleneck):
97
- def __init__(self, noise_augment_dim: int = 0):
98
- super().__init__(is_discrete=False)
99
-
100
- self.noise_augment_dim = noise_augment_dim
101
-
102
- def encode(self, x, return_info=False):
103
- info = {}
104
-
105
- if self.training and return_info:
106
- mmd = compute_mmd(x)
107
- info["mmd"] = mmd
108
-
109
- if return_info:
110
- return x, info
111
-
112
- return x
113
-
114
- def decode(self, x):
115
-
116
- if self.noise_augment_dim > 0:
117
- noise = torch.randn(x.shape[0], self.noise_augment_dim,
118
- x.shape[-1]).type_as(x)
119
- x = torch.cat([x, noise], dim=1)
120
-
121
- return x
122
-
123
- class L2Bottleneck(Bottleneck):
124
- def __init__(self):
125
- super().__init__(is_discrete=False)
126
-
127
- def encode(self, x, return_info=False):
128
- info = {}
129
-
130
- x = F.normalize(x, dim=1)
131
-
132
- if return_info:
133
- return x, info
134
- else:
135
- return x
136
-
137
- def decode(self, x):
138
- return F.normalize(x, dim=1)
139
-
140
- class RVQBottleneck(DiscreteBottleneck):
141
- def __init__(self, **quantizer_kwargs):
142
- super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
143
- self.quantizer = ResidualVQ(**quantizer_kwargs)
144
- self.num_quantizers = quantizer_kwargs["num_quantizers"]
145
-
146
- def encode(self, x, return_info=False, **kwargs):
147
- info = {}
148
-
149
- x = rearrange(x, "b c n -> b n c")
150
- x, indices, loss = self.quantizer(x)
151
- x = rearrange(x, "b n c -> b c n")
152
-
153
- info["quantizer_indices"] = indices
154
- info["quantizer_loss"] = loss.mean()
155
-
156
- if return_info:
157
- return x, info
158
- else:
159
- return x
160
-
161
- def decode(self, x):
162
- return x
163
-
164
- def decode_tokens(self, codes, **kwargs):
165
- latents = self.quantizer.get_outputs_from_indices(codes)
166
-
167
- return self.decode(latents, **kwargs)
168
-
169
- class RVQVAEBottleneck(DiscreteBottleneck):
170
- def __init__(self, **quantizer_kwargs):
171
- super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
172
- self.quantizer = ResidualVQ(**quantizer_kwargs)
173
- self.num_quantizers = quantizer_kwargs["num_quantizers"]
174
-
175
- def encode(self, x, return_info=False):
176
- info = {}
177
-
178
- x, kl = vae_sample(*x.chunk(2, dim=1))
179
-
180
- info["kl"] = kl
181
-
182
- x = rearrange(x, "b c n -> b n c")
183
- x, indices, loss = self.quantizer(x)
184
- x = rearrange(x, "b n c -> b c n")
185
-
186
- info["quantizer_indices"] = indices
187
- info["quantizer_loss"] = loss.mean()
188
-
189
- if return_info:
190
- return x, info
191
- else:
192
- return x
193
-
194
- def decode(self, x):
195
- return x
196
-
197
- def decode_tokens(self, codes, **kwargs):
198
- latents = self.quantizer.get_outputs_from_indices(codes)
199
-
200
- return self.decode(latents, **kwargs)
201
-
202
- class DACRVQBottleneck(DiscreteBottleneck):
203
- def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
204
- super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
205
- self.quantizer = DACResidualVQ(**quantizer_kwargs)
206
- self.num_quantizers = quantizer_kwargs["n_codebooks"]
207
- self.quantize_on_decode = quantize_on_decode
208
-
209
- def encode(self, x, return_info=False, **kwargs):
210
- info = {}
211
-
212
- info["pre_quantizer"] = x
213
-
214
- if self.quantize_on_decode:
215
- return (x, info) if return_info else x
216
-
217
- z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
218
-
219
- output = {
220
- "z": z,
221
- "codes": codes,
222
- "latents": latents,
223
- "vq/commitment_loss": commitment_loss,
224
- "vq/codebook_loss": codebook_loss,
225
- }
226
-
227
- output["vq/commitment_loss"] /= self.num_quantizers
228
- output["vq/codebook_loss"] /= self.num_quantizers
229
-
230
- info.update(output)
231
-
232
- if return_info:
233
- return output["z"], info
234
-
235
- return output["z"]
236
-
237
- def decode(self, x):
238
-
239
- if self.quantize_on_decode:
240
- x = self.quantizer(x)[0]
241
-
242
- return x
243
-
244
- def decode_tokens(self, codes, **kwargs):
245
- latents, _, _ = self.quantizer.from_codes(codes)
246
-
247
- return self.decode(latents, **kwargs)
248
-
249
- class DACRVQVAEBottleneck(DiscreteBottleneck):
250
- def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
251
- super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
252
- self.quantizer = DACResidualVQ(**quantizer_kwargs)
253
- self.num_quantizers = quantizer_kwargs["n_codebooks"]
254
- self.quantize_on_decode = quantize_on_decode
255
-
256
- def encode(self, x, return_info=False, n_quantizers: int = None):
257
- info = {}
258
-
259
- mean, scale = x.chunk(2, dim=1)
260
-
261
- x, kl = vae_sample(mean, scale)
262
-
263
- info["pre_quantizer"] = x
264
- info["kl"] = kl
265
-
266
- if self.quantize_on_decode:
267
- return (x, info) if return_info else x
268
-
269
- z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
270
-
271
- output = {
272
- "z": z,
273
- "codes": codes,
274
- "latents": latents,
275
- "vq/commitment_loss": commitment_loss,
276
- "vq/codebook_loss": codebook_loss,
277
- }
278
-
279
- output["vq/commitment_loss"] /= self.num_quantizers
280
- output["vq/codebook_loss"] /= self.num_quantizers
281
-
282
- info.update(output)
283
-
284
- if return_info:
285
- return output["z"], info
286
-
287
- return output["z"]
288
-
289
- def decode(self, x):
290
-
291
- if self.quantize_on_decode:
292
- x = self.quantizer(x)[0]
293
-
294
- return x
295
-
296
- def decode_tokens(self, codes, **kwargs):
297
- latents, _, _ = self.quantizer.from_codes(codes)
298
-
299
- return self.decode(latents, **kwargs)
300
-
301
- class FSQBottleneck(DiscreteBottleneck):
302
- def __init__(self, dim, levels):
303
- super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
304
- self.quantizer = FSQ(levels=[levels] * dim)
305
-
306
- def encode(self, x, return_info=False):
307
- info = {}
308
-
309
- x = rearrange(x, "b c n -> b n c")
310
- x, indices = self.quantizer(x)
311
- x = rearrange(x, "b n c -> b c n")
312
-
313
- info["quantizer_indices"] = indices
314
-
315
- if return_info:
316
- return x, info
317
- else:
318
- return x
319
-
320
- def decode(self, x):
321
- return x
322
-
323
- def decode_tokens(self, tokens, **kwargs):
324
- latents = self.quantizer.indices_to_codes(tokens)
325
-
326
- return self.decode(latents, **kwargs)
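A hedged round-trip sketch for the discrete bottlenecks above, using FSQBottleneck; the dimensions are arbitrary and the quantizer behaviour comes from vector_quantize_pytorch.

```python
# Hedged sketch: encode exposes token indices under info["quantizer_indices"],
# and decode_tokens maps indices back to (continuous) latents.
import torch

bottleneck = FSQBottleneck(dim=6, levels=5)   # codebook_size = 5 ** 6
z = torch.randn(2, 6, 100)                    # (batch, dim, latent_steps)

z_q, info = bottleneck.encode(z, return_info=True)
tokens = info["quantizer_indices"]            # (batch, latent_steps)
z_rec = bottleneck.decode_tokens(tokens)      # same values as z_q, up to tensor layout
```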
stable/build/lib/stable_audio_tools/models/codebook_patterns.py DELETED
@@ -1,545 +0,0 @@
1
- # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License
2
- # License available in LICENSES/LICENSE_META.txt
3
-
4
- from collections import namedtuple
5
- from dataclasses import dataclass
6
- from functools import lru_cache
7
- import logging
8
- import typing as tp
9
-
10
- from abc import ABC, abstractmethod
11
- import torch
12
-
13
- LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
14
- PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- @dataclass
19
- class Pattern:
20
- """Base implementation of a pattern over a sequence with multiple codebooks.
21
-
22
- The codebook pattern consists of a layout, defining for each sequence step
23
- the list of coordinates of each codebook timestep in the resulting interleaved sequence.
24
- The first item of the pattern is always an empty list in order to properly insert a special token
25
- to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
26
- and ``timesteps`` the number of timesteps corresponding to the original sequence.
27
-
28
- The pattern provides convenient methods to build and revert interleaved sequences from it:
29
- ``build_pattern_sequence`` maps a given dense input tensor of a multi-codebook sequence from [B, K, T]
30
- to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
31
- K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
32
- for the output sequence. The unfilled positions are replaced with a special token and the built sequence
33
- is returned along with a mask indicating valid tokens.
34
- ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
35
- of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
36
- to fill and specify invalid positions if needed.
37
- See the dedicated methods for more details.
38
- """
39
- # Pattern layout, for each sequence step, we have a list of coordinates
40
- # corresponding to the original codebook timestep and position.
41
- # The first list is always an empty list in order to properly insert
42
- # a special token to start with.
43
- layout: PatternLayout
44
- timesteps: int
45
- n_q: int
46
-
47
- def __post_init__(self):
48
- assert len(self.layout) > 0
49
- self._validate_layout()
50
- self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
51
- self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
52
- logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
53
-
54
- def _validate_layout(self):
55
- """Runs checks on the layout to ensure a valid pattern is defined.
56
- A pattern is considered invalid if:
57
- - Multiple timesteps for a same codebook are defined in the same sequence step
58
- - The timesteps for a given codebook are not in ascending order as we advance in the sequence
59
- (this would mean that we have future timesteps before past timesteps).
60
- """
61
- q_timesteps = {q: 0 for q in range(self.n_q)}
62
- for s, seq_coords in enumerate(self.layout):
63
- if len(seq_coords) > 0:
64
- qs = set()
65
- for coord in seq_coords:
66
- qs.add(coord.q)
67
- last_q_timestep = q_timesteps[coord.q]
68
- assert coord.t >= last_q_timestep, \
69
- f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
70
- q_timesteps[coord.q] = coord.t
71
- # each sequence step contains at max 1 coordinate per codebook
72
- assert len(qs) == len(seq_coords), \
73
- f"Multiple entries for a same codebook are found at step {s}"
74
-
75
- @property
76
- def num_sequence_steps(self):
77
- return len(self.layout) - 1
78
-
79
- @property
80
- def max_delay(self):
81
- max_t_in_seq_coords = 0
82
- for seq_coords in self.layout[1:]:
83
- for coords in seq_coords:
84
- max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
85
- return max_t_in_seq_coords - self.timesteps
86
-
87
- @property
88
- def valid_layout(self):
89
- valid_step = len(self.layout) - self.max_delay
90
- return self.layout[:valid_step]
91
-
92
- def starts_with_special_token(self):
93
- return self.layout[0] == []
94
-
95
- def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
96
- """Get codebook coordinates in the layout that corresponds to the specified timestep t
97
- and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
98
- and the actual codebook coordinates.
99
- """
100
- assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
101
- if q is not None:
102
- assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
103
- coords = []
104
- for s, seq_codes in enumerate(self.layout):
105
- for code in seq_codes:
106
- if code.t == t and (q is None or code.q == q):
107
- coords.append((s, code))
108
- return coords
109
-
110
- def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
111
- return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
112
-
113
- def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
114
- steps_with_timesteps = self.get_steps_with_timestep(t, q)
115
- return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
116
-
117
- def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
118
- device: tp.Union[torch.device, str] = 'cpu'):
119
- """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
120
-
121
- Args:
122
- timesteps (int): Maximum number of timesteps steps to consider.
123
- keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
124
- device (torch.device or str): Device for created tensors.
125
- Returns:
126
- indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127
- mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128
- """
129
- assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130
- assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131
- # use the proper layout based on whether we limit ourselves to valid steps only or not,
132
- # note that using the valid_layout will result in a truncated sequence up to the valid steps
133
- ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
134
- # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
135
- indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
136
- mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
137
- # fill indexes with last sequence step value that will correspond to our special token
138
- # the last value is n_q * timesteps as we have flattened z and append special token as the last token
139
- # which will correspond to the index: n_q * timesteps
140
- indexes[:] = n_q * timesteps
141
- # iterate over the pattern and fill scattered indexes and mask
142
- for s, sequence_coords in enumerate(ref_layout):
143
- for coords in sequence_coords:
144
- if coords.t < timesteps:
145
- indexes[coords.q, s] = coords.t + coords.q * timesteps
146
- mask[coords.q, s] = 1
147
- indexes = torch.from_numpy(indexes).to(device)
148
- mask = torch.from_numpy(mask).to(device)
149
- return indexes, mask
150
-
151
- def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
152
- """Build sequence corresponding to the pattern from the input tensor z.
153
- The sequence is built using up to sequence_steps if specified, and non-pattern
154
- coordinates are filled with the special token.
155
-
156
- Args:
157
- z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
158
- special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
159
- keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
160
- Steps that are beyond valid steps will be replaced by the special_token in that case.
161
- Returns:
162
- values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
163
- corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
164
- indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
165
- mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
166
- """
167
- B, K, T = z.shape
168
- indexes, mask = self._build_pattern_sequence_scatter_indexes(
169
- T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
170
- )
171
- z = z.view(B, -1)
172
- # we append the special token as the last index of our flattened z tensor
173
- z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
174
- values = z[:, indexes.view(-1)]
175
- values = values.view(B, K, indexes.shape[-1])
176
- return values, indexes, mask
177
-
178
- def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
179
- keep_only_valid_steps: bool = False,
180
- is_model_output: bool = False,
181
- device: tp.Union[torch.device, str] = 'cpu'):
182
- """Builds scatter indexes required to retrieve the original multi-codebook sequence
183
- from interleaving pattern.
184
-
185
- Args:
186
- sequence_steps (int): Sequence steps.
187
- n_q (int): Number of codebooks.
188
- keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
189
- Steps that are beyond valid steps will be replaced by the special_token in that case.
190
- is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
191
- device (torch.device or str): Device for created tensors.
192
- Returns:
193
- indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
194
- mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
195
- """
196
- ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197
- # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198
- timesteps = self.timesteps
199
- assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200
- assert sequence_steps <= len(ref_layout), \
201
- f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
-
203
- # ensure we take the appropriate indexes to keep the model output from the first special token as well
204
- if is_model_output and self.starts_with_special_token():
205
- ref_layout = ref_layout[1:]
206
-
207
- # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
208
- indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
209
- mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
210
- # fill indexes with last sequence step value that will correspond to our special token
211
- indexes[:] = n_q * sequence_steps
212
- for s, sequence_codes in enumerate(ref_layout):
213
- if s < sequence_steps:
214
- for code in sequence_codes:
215
- if code.t < timesteps:
216
- indexes[code.q, code.t] = s + code.q * sequence_steps
217
- mask[code.q, code.t] = 1
218
- indexes = torch.from_numpy(indexes).to(device)
219
- mask = torch.from_numpy(mask).to(device)
220
- return indexes, mask
221
-
222
- def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
223
- """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
224
- The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
225
- are filled with the special token.
226
-
227
- Args:
228
- s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
229
- special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
230
- Returns:
231
- values (torch.Tensor): De-interleaved sequence restored from the pattern, of shape [B, K, T] with T
232
- corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
233
- indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
234
- mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
- """
236
- B, K, S = s.shape
237
- indexes, mask = self._build_reverted_sequence_scatter_indexes(
238
- S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
239
- )
240
- s = s.view(B, -1)
241
- # we append the special token as the last index of our flattened z tensor
242
- s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
243
- values = s[:, indexes.view(-1)]
244
- values = values.view(B, K, indexes.shape[-1])
245
- return values, indexes, mask
246
-
247
- def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
248
- """Revert model logits obtained on a sequence built from the pattern
249
- back to a tensor matching the original sequence.
250
-
251
- This method is similar to ``revert_pattern_sequence`` with the following specificities:
252
- 1. It is designed to work with the extra cardinality dimension
253
- 2. We return the logits for the first sequence item that matches the special_token and
254
- whose matching target in the original sequence is the first item of the sequence,
255
- while we skip the last logits as there is no matching target
256
- """
257
- B, card, K, S = logits.shape
258
- indexes, mask = self._build_reverted_sequence_scatter_indexes(
259
- S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
260
- )
261
- logits = logits.reshape(B, card, -1)
262
- # we append the special token as the last index of our flattened z tensor
263
- logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
264
- values = logits[:, :, indexes.view(-1)]
265
- values = values.view(B, card, K, indexes.shape[-1])
266
- return values, indexes, mask
267
-
268
-
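The flatten-and-gather trick used in `build_pattern_sequence` above (flatten `z`, append the special token as the final element, then index with precomputed scatter indexes) can be illustrated on its own. A minimal sketch with toy shapes and hand-written indexes, not taken from the file itself:

```python
import torch

B, K, T = 1, 2, 3
special_token = 99
z = torch.arange(B * K * T).view(B, K, T)    # toy codes, shape [B, K, T]

# Indexes into the flattened z (length K*T); the value K*T points at the appended special token.
indexes = torch.tensor([[0, 1, 2, K * T],    # codebook 0: timesteps 0..2, then the special token
                        [K * T, 3, 4, 5]])   # codebook 1: special token, then timesteps 0..2

flat = torch.cat([z.view(B, -1), torch.full((B, 1), special_token)], dim=1)
values = flat[:, indexes.view(-1)].view(B, K, indexes.shape[-1])
print(values)                                # -> [[[0, 1, 2, 99], [99, 3, 4, 5]]]
```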
269
- class CodebooksPatternProvider(ABC):
270
- """Abstraction around providing pattern for interleaving codebooks.
271
-
272
- The CodebooksPatternProvider abstraction makes it possible to implement various strategies for
273
- defining the interleaving pattern of sequences composed of multiple codebooks. For a given
274
- number of codebooks `n_q`, the pattern provider can generate a specified pattern
275
- corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
276
- can be used to construct a new sequence from the original codes respecting the specified
277
- pattern. The pattern is defined as a list of list of code coordinates, code coordinate
278
- being a tuple with the original timestep and codebook to build the new sequence.
279
- Note that all patterns must start with an empty list that is then used to insert a first
280
- sequence step of special tokens in the newly generated sequence.
281
-
282
- Args:
283
- n_q (int): number of codebooks.
284
- cached (bool): if True, patterns for a given length are cached. In general
285
- that should be true for efficiency reasons, to avoid synchronization points.
286
- """
287
- def __init__(self, n_q: int, cached: bool = True):
288
- assert n_q > 0
289
- self.n_q = n_q
290
- self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
291
-
292
- @abstractmethod
293
- def get_pattern(self, timesteps: int) -> Pattern:
294
- """Builds pattern with specific interleaving between codebooks.
295
-
296
- Args:
297
- timesteps (int): Total number of timesteps.
298
- """
299
- raise NotImplementedError()
300
-
301
-
302
- class DelayedPatternProvider(CodebooksPatternProvider):
303
- """Provider for delayed pattern across delayed codebooks.
304
- Codebooks are delayed in the sequence and sequence steps will contain codebooks
305
- from different timesteps.
306
-
307
- Example:
308
- Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
309
- [[1, 2, 3, 4],
310
- [1, 2, 3, 4],
311
- [1, 2, 3, 4]]
312
- The resulting sequence obtained from the returned pattern is:
313
- [[S, 1, 2, 3, 4],
314
- [S, S, 1, 2, 3],
315
- [S, S, S, 1, 2]]
316
- (with S being a special token)
317
-
318
- Args:
319
- n_q (int): Number of codebooks.
320
- delays (list of int, optional): Delay for each of the codebooks.
321
- If delays not defined, each codebook is delayed by 1 compared to the previous one.
322
- flatten_first (int): Flatten the first N timesteps.
323
- empty_initial (int): Prepend with N empty list of coordinates.
324
- """
325
- def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
326
- flatten_first: int = 0, empty_initial: int = 0):
327
- super().__init__(n_q)
328
- if delays is None:
329
- delays = list(range(n_q))
330
- self.delays = delays
331
- self.flatten_first = flatten_first
332
- self.empty_initial = empty_initial
333
- assert len(self.delays) == self.n_q
334
- assert sorted(self.delays) == self.delays
335
-
336
- def get_pattern(self, timesteps: int) -> Pattern:
337
- omit_special_token = self.empty_initial < 0
338
- out: PatternLayout = [] if omit_special_token else [[]]
339
- max_delay = max(self.delays)
340
- if self.empty_initial:
341
- out += [[] for _ in range(self.empty_initial)]
342
- if self.flatten_first:
343
- for t in range(min(timesteps, self.flatten_first)):
344
- for q in range(self.n_q):
345
- out.append([LayoutCoord(t, q)])
346
- for t in range(self.flatten_first, timesteps + max_delay):
347
- v = []
348
- for q, delay in enumerate(self.delays):
349
- t_for_q = t - delay
350
- if t_for_q >= self.flatten_first:
351
- v.append(LayoutCoord(t_for_q, q))
352
- out.append(v)
353
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
354
-
355
-
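A short usage sketch for the delayed pattern described above. The import path (`stable_audio_tools.models.codebook_patterns`) is assumed from this file's location, and the shapes mirror the n_q=3, timesteps=4 docstring example:

```python
import torch
from stable_audio_tools.models.codebook_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=3)     # default delays = [0, 1, 2]
pattern = provider.get_pattern(timesteps=4)

z = torch.arange(1, 13).view(1, 3, 4)        # toy codes, shape [B=1, K=3, T=4]
values, indexes, mask = pattern.build_pattern_sequence(z, special_token=0)
print(values.shape)                          # torch.Size([1, 3, 7]): 1 leading special step + 4 timesteps + max delay of 2
```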
356
- class ParallelPatternProvider(DelayedPatternProvider):
357
- """Provider for parallel pattern across codebooks.
358
- This pattern provider is a special case of the delayed pattern with actually no delay,
359
- hence delays=repeat(0, n_q).
360
-
361
- Args:
362
- n_q (int): Number of codebooks.
363
- empty_initial (int): Prepend with N empty list of coordinates.
364
- """
365
- def __init__(self, n_q: int, empty_initial: int = 0):
366
- super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
367
-
368
-
369
- class UnrolledPatternProvider(CodebooksPatternProvider):
370
- """Provider for unrolling codebooks pattern.
371
- This pattern provider can represent the codebooks fully flattened or only partially flattened,
372
- while also specifying a per-codebook delay in the flattened representation, allowing the
373
- codebooks to be unrolled in the sequence.
374
-
375
- Example:
376
- 1. Flattening of the codebooks.
377
- By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
378
- taking n_q = 3 and timesteps = 4:
379
- [[1, 2, 3, 4],
380
- [1, 2, 3, 4],
381
- [1, 2, 3, 4]]
382
- will result in:
383
- [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
384
- [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
385
- [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
386
- 2. Partial flattening of the codebooks. The ``flattening`` parameter specifies the inner step
388
- for each codebook, defining which codebooks to flatten (or keep in parallel), for example
388
- taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
389
- [[1, 2, 3, 4],
390
- [1, 2, 3, 4],
391
- [1, 2, 3, 4]]
392
- will result in:
393
- [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
394
- [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
395
- [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
396
- 3. Flattening with delay. The ``delays`` parameter further unrolls the sequence of codebooks
398
- by specifying a delay per codebook. Note that the delay between codebooks flattened to the
398
- same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
399
- and delays = [0, 3, 3]:
400
- [[1, 2, 3, 4],
401
- [1, 2, 3, 4],
402
- [1, 2, 3, 4]]
403
- will result in:
404
- [[S, S, S, 1, S, 2, S, 3, S, 4],
405
- [S, S, S, 1, S, 2, S, 3, S, 4],
406
- [1, 2, 3, S, 4, S, 5, S, 6, S]]
407
-
408
- Args:
409
- n_q (int): Number of codebooks.
410
- flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
411
- the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
412
- have n_q extra steps for each timestep.
413
- delays (list of int, optional): Delay for each of the codebooks. If not defined,
414
- no delay is added and therefore will default to [0] * ``n_q``.
415
- Note that two codebooks that will be flattened to the same inner step
416
- should have the same delay, otherwise the pattern is considered as invalid.
417
- """
418
- FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
419
-
420
- def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
421
- delays: tp.Optional[tp.List[int]] = None):
422
- super().__init__(n_q)
423
- if flattening is None:
424
- flattening = list(range(n_q))
425
- if delays is None:
426
- delays = [0] * n_q
427
- assert len(flattening) == n_q
428
- assert len(delays) == n_q
429
- assert sorted(flattening) == flattening
430
- assert sorted(delays) == delays
431
- self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
432
- self.max_delay = max(delays)
433
-
434
- def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
435
- """Build a flattened codebooks representation as a dictionary of inner step
436
- and the actual codebook indices corresponding to the flattened codebook. For convenience, we
437
- also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
438
- """
439
- flattened_codebooks: dict = {}
440
- for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
441
- if inner_step not in flattened_codebooks:
442
- flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
443
- else:
444
- flat_codebook = flattened_codebooks[inner_step]
445
- assert flat_codebook.delay == delay, (
446
- "Delay and flattening between codebooks is inconsistent: ",
447
- "two codebooks flattened to the same position should have the same delay."
448
- )
449
- flat_codebook.codebooks.append(q)
450
- flattened_codebooks[inner_step] = flat_codebook
451
- return flattened_codebooks
452
-
453
- @property
454
- def _num_inner_steps(self):
455
- """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
456
- """
457
- return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
458
-
459
- def num_virtual_steps(self, timesteps: int) -> int:
460
- return timesteps * self._num_inner_steps + 1
461
-
462
- def get_pattern(self, timesteps: int) -> Pattern:
463
- """Builds pattern for delay across codebooks.
464
-
465
- Args:
466
- timesteps (int): Total number of timesteps.
467
- """
468
- # the PatternLayout is built as a tuple of sequence position and list of coordinates
469
- # so that it can be reordered properly given the required delay between codebooks of given timesteps
470
- indexed_out: list = [(-1, [])]
471
- max_timesteps = timesteps + self.max_delay
472
- for t in range(max_timesteps):
473
- # for each timestep, we unroll the flattened codebooks,
474
- # emitting the sequence step with the corresponding delay
475
- for step in range(self._num_inner_steps):
476
- if step in self._flattened_codebooks:
477
- # we have codebooks at this virtual step to emit
478
- step_codebooks = self._flattened_codebooks[step]
479
- t_for_q = t + step_codebooks.delay
480
- coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
481
- if t_for_q < max_timesteps and t < max_timesteps:
482
- indexed_out.append((t_for_q, coords))
483
- else:
484
- # there is no codebook in this virtual step so we emit an empty list
485
- indexed_out.append((t, []))
486
- out = [coords for _, coords in sorted(indexed_out)]
487
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
488
-
489
-
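A sketch of the partial-flattening case from the docstring above (same assumed import path; flattening=[0, 1, 1] keeps codebooks 1 and 2 on the same inner step):

```python
from stable_audio_tools.models.codebook_patterns import UnrolledPatternProvider

provider = UnrolledPatternProvider(n_q=3, flattening=[0, 1, 1])
pattern = provider.get_pattern(timesteps=4)
# 2 inner steps per timestep, plus the leading special-token step
print(len(pattern.layout))                   # 9 == provider.num_virtual_steps(4)
```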
490
- class CoarseFirstPattern(CodebooksPatternProvider):
491
- """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
492
- potentially with delays.
493
-
494
- ..Warning:: You must always generate the full training duration at test time, for instance,
495
- 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
496
- location. This is due to the non causality of the remaining codebooks with respect to
497
- the first ones.
498
-
499
- Args:
500
- n_q (int): Number of codebooks.
501
- delays (list of int, optional): Delay for each of the codebooks.
502
- If delays are not defined, no extra delay is applied (defaults to [0] * (n_q - 1)).
503
- """
504
- def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
505
- super().__init__(n_q)
506
- if delays is None:
507
- delays = [0] * (n_q - 1)
508
- self.delays = delays
509
- assert len(self.delays) == self.n_q - 1
510
- assert sorted(self.delays) == self.delays
511
-
512
- def get_pattern(self, timesteps: int) -> Pattern:
513
- out: PatternLayout = [[]]
514
- for t in range(timesteps):
515
- out.append([LayoutCoord(t, 0)])
516
- max_delay = max(self.delays)
517
- for t in range(timesteps + max_delay):
518
- v = []
519
- for q, delay in enumerate(self.delays):
520
- t_for_q = t - delay
521
- if t_for_q >= 0:
522
- v.append(LayoutCoord(t_for_q, q + 1))
523
- out.append(v)
524
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
525
-
526
-
527
- class MusicLMPattern(CodebooksPatternProvider):
528
- """Almost MusicLM style pattern. This is equivalent to full flattening
529
- but in a different order.
530
-
531
- Args:
532
- n_q (int): Number of codebooks.
533
- group_by (int): Number of codebooks to group together.
534
- """
535
- def __init__(self, n_q: int, group_by: int = 2):
536
- super().__init__(n_q)
537
- self.group_by = group_by
538
-
539
- def get_pattern(self, timesteps: int) -> Pattern:
540
- out: PatternLayout = [[]]
541
- for offset in range(0, self.n_q, self.group_by):
542
- for t in range(timesteps):
543
- for q in range(offset, offset + self.group_by):
544
- out.append([LayoutCoord(t, q)])
545
- return Pattern(out, n_q=self.n_q, timesteps=timesteps)
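Taken together, these providers are used for a build/revert round trip around the interleaved token sequence. A minimal sketch of that round trip (import path assumed as above; the special token value is arbitrary but must lie outside the code range):

```python
import torch
from stable_audio_tools.models.codebook_patterns import DelayedPatternProvider

pattern = DelayedPatternProvider(n_q=4).get_pattern(timesteps=8)

codes = torch.randint(0, 1024, (2, 4, 8))                     # [B, K, T] toy codebook indices
seq, _, _ = pattern.build_pattern_sequence(codes, special_token=1024)

# ... a token-level model would operate on `seq` here ...

reverted, _, _ = pattern.revert_pattern_sequence(seq, special_token=1024)
assert torch.equal(reverted, codes)                           # the delayed pattern round-trips exactly
```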
stable/build/lib/stable_audio_tools/models/conditioners.py DELETED
@@ -1,561 +0,0 @@
1
- #Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
2
-
3
- import torch
4
- import logging, warnings
5
- import string
6
- import typing as tp
7
- import gc
8
-
9
- from .adp import NumberEmbedder
10
- from ..inference.utils import set_audio_channels
11
- from .factory import create_pretransform_from_config
12
- from .pretransforms import Pretransform
13
- from ..training.utils import copy_state_dict
14
- from .utils import load_ckpt_state_dict
15
-
16
- from torch import nn
17
-
18
- class Conditioner(nn.Module):
19
- def __init__(
20
- self,
21
- dim: int,
22
- output_dim: int,
23
- project_out: bool = False
24
- ):
25
-
26
- super().__init__()
27
-
28
- self.dim = dim
29
- self.output_dim = output_dim
30
- self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
31
-
32
- def forward(self, x: tp.Any) -> tp.Any:
33
- raise NotImplementedError()
34
-
35
- class IntConditioner(Conditioner):
36
- def __init__(self,
37
- output_dim: int,
38
- min_val: int=0,
39
- max_val: int=512
40
- ):
41
- super().__init__(output_dim, output_dim)
42
-
43
- self.min_val = min_val
44
- self.max_val = max_val
45
- self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
46
-
47
- def forward(self, ints: tp.List[int], device=None) -> tp.Any:
48
-
49
- #self.int_embedder.to(device)
50
-
51
- ints = torch.tensor(ints).to(device)
52
- ints = ints.clamp(self.min_val, self.max_val)
53
-
54
- int_embeds = self.int_embedder(ints).unsqueeze(1)
55
-
56
- return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
57
-
58
- class NumberConditioner(Conditioner):
59
- '''
60
- Conditioner that takes a list of floats, normalizes them to a given range, and returns a list of embeddings
61
- '''
62
- def __init__(self,
63
- output_dim: int,
64
- min_val: float=0,
65
- max_val: float=1
66
- ):
67
- super().__init__(output_dim, output_dim)
68
-
69
- self.min_val = min_val
70
- self.max_val = max_val
71
-
72
- self.embedder = NumberEmbedder(features=output_dim)
73
-
74
- def forward(self, floats: tp.List[float], device=None) -> tp.Any:
75
-
76
- # Cast the inputs to floats
77
- floats = [float(x) for x in floats]
78
-
79
- floats = torch.tensor(floats).to(device)
80
-
81
- floats = floats.clamp(self.min_val, self.max_val)
82
-
83
- normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
84
-
85
- # Cast floats to same type as embedder
86
- embedder_dtype = next(self.embedder.parameters()).dtype
87
- normalized_floats = normalized_floats.to(embedder_dtype)
88
-
89
- float_embeds = self.embedder(normalized_floats).unsqueeze(1)
90
-
91
- return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
92
-
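The normalization above is a plain min-max rescale applied before the learned embedding. A standalone sketch of just that arithmetic, with arbitrary example values:

```python
import torch

min_val, max_val = 0.0, 30.0                 # e.g. conditioning on a duration in seconds
floats = torch.tensor([5.0, 45.0, 15.0]).clamp(min_val, max_val)
normalized = (floats - min_val) / (max_val - min_val)
print(normalized)                            # tensor([0.1667, 1.0000, 0.5000])
```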
93
- class CLAPTextConditioner(Conditioner):
94
- def __init__(self,
95
- output_dim: int,
96
- clap_ckpt_path,
97
- use_text_features = False,
98
- feature_layer_ix: int = -1,
99
- audio_model_type="HTSAT-base",
100
- enable_fusion=True,
101
- project_out: bool = False,
102
- finetune: bool = False):
103
- super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
104
-
105
- self.use_text_features = use_text_features
106
- self.feature_layer_ix = feature_layer_ix
107
- self.finetune = finetune
108
-
109
- # Suppress logging from transformers
110
- previous_level = logging.root.manager.disable
111
- logging.disable(logging.ERROR)
112
- with warnings.catch_warnings():
113
- warnings.simplefilter("ignore")
114
- try:
115
- import laion_clap
116
- from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
117
-
118
- model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
119
-
120
- if self.finetune:
121
- self.model = model
122
- else:
123
- self.__dict__["model"] = model
124
-
125
- state_dict = clap_load_state_dict(clap_ckpt_path)
126
- self.model.model.load_state_dict(state_dict, strict=False)
127
-
128
- if self.finetune:
129
- self.model.model.text_branch.requires_grad_(True)
130
- self.model.model.text_branch.train()
131
- else:
132
- self.model.model.text_branch.requires_grad_(False)
133
- self.model.model.text_branch.eval()
134
-
135
- finally:
136
- logging.disable(previous_level)
137
-
138
- del self.model.model.audio_branch
139
-
140
- gc.collect()
141
- torch.cuda.empty_cache()
142
-
143
- def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
144
- prompt_tokens = self.model.tokenizer(prompts)
145
- attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
146
- prompt_features = self.model.model.text_branch(
147
- input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
148
- attention_mask=attention_mask,
149
- output_hidden_states=True
150
- )["hidden_states"][layer_ix]
151
-
152
- return prompt_features, attention_mask
153
-
154
- def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
155
- self.model.to(device)
156
-
157
- if self.use_text_features:
158
- if len(texts) == 1:
159
- text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
160
- text_features = text_features[:1, ...]
161
- text_attention_mask = text_attention_mask[:1, ...]
162
- else:
163
- text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
164
- return [self.proj_out(text_features), text_attention_mask]
165
-
166
- # Fix for CLAP bug when only one text is passed
167
- if len(texts) == 1:
168
- text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
169
- else:
170
- text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
171
-
172
- text_embedding = text_embedding.unsqueeze(1).to(device)
173
-
174
- return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
175
-
176
- class CLAPAudioConditioner(Conditioner):
177
- def __init__(self,
178
- output_dim: int,
179
- clap_ckpt_path,
180
- audio_model_type="HTSAT-base",
181
- enable_fusion=True,
182
- project_out: bool = False):
183
- super().__init__(512, output_dim, project_out=project_out)
184
- self.finetune = False  # not exposed as a constructor argument here (unlike CLAPTextConditioner), but referenced below
-
185
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
186
-
187
- # Suppress logging from transformers
188
- previous_level = logging.root.manager.disable
189
- logging.disable(logging.ERROR)
190
- with warnings.catch_warnings():
191
- warnings.simplefilter("ignore")
192
- try:
193
- import laion_clap
194
- from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
195
-
196
- model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
197
-
198
- if self.finetune:
199
- self.model = model
200
- else:
201
- self.__dict__["model"] = model
202
-
203
- state_dict = clap_load_state_dict(clap_ckpt_path)
204
- self.model.model.load_state_dict(state_dict, strict=False)
205
-
206
- if self.finetune:
207
- self.model.model.audio_branch.requires_grad_(True)
208
- self.model.model.audio_branch.train()
209
- else:
210
- self.model.model.audio_branch.requires_grad_(False)
211
- self.model.model.audio_branch.eval()
212
-
213
- finally:
214
- logging.disable(previous_level)
215
-
216
- del self.model.model.text_branch
217
-
218
- gc.collect()
219
- torch.cuda.empty_cache()
220
-
221
- def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
222
-
223
- self.model.to(device)
224
-
225
- if isinstance(audios, list) or isinstance(audios, tuple):
226
- audios = torch.cat(audios, dim=0)
227
-
228
- # Convert to mono
229
- mono_audios = audios.mean(dim=1)
230
-
231
- with torch.cuda.amp.autocast(enabled=False):
232
- audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
233
-
234
- audio_embedding = audio_embedding.unsqueeze(1).to(device)
235
-
236
- return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
237
-
238
- class T5Conditioner(Conditioner):
239
-
240
- T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
241
- "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
242
- "google/flan-t5-xl", "google/flan-t5-xxl"]
243
-
244
- T5_MODEL_DIMS = {
245
- "t5-small": 512,
246
- "t5-base": 768,
247
- "t5-large": 1024,
248
- "t5-3b": 1024,
249
- "t5-11b": 1024,
250
- "t5-xl": 2048,
251
- "t5-xxl": 4096,
252
- "google/flan-t5-small": 512,
253
- "google/flan-t5-base": 768,
254
- "google/flan-t5-large": 1024,
255
- "google/flan-t5-3b": 1024,
256
- "google/flan-t5-11b": 1024,
257
- "google/flan-t5-xl": 2048,
258
- "google/flan-t5-xxl": 4096,
259
- }
260
-
261
- def __init__(
262
- self,
263
- output_dim: int,
264
- t5_model_name: str = "t5-base",
265
- max_length: int = 128,
266
- enable_grad: bool = False,
267
- project_out: bool = False
268
- ):
269
- assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
270
- super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
271
-
272
- from transformers import T5EncoderModel, AutoTokenizer
273
-
274
- self.max_length = max_length
275
- self.enable_grad = enable_grad
276
-
277
- # Suppress logging from transformers
278
- previous_level = logging.root.manager.disable
279
- logging.disable(logging.ERROR)
280
- with warnings.catch_warnings():
281
- warnings.simplefilter("ignore")
282
- try:
283
- # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
284
- # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
285
- self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
286
- model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
287
- finally:
288
- logging.disable(previous_level)
289
-
290
- if self.enable_grad:
291
- self.model = model
292
- else:
293
- self.__dict__["model"] = model
294
-
295
-
296
- def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
297
-
298
- self.model.to(device)
299
- self.proj_out.to(device)
300
-
301
- encoded = self.tokenizer(
302
- texts,
303
- truncation=True,
304
- max_length=self.max_length,
305
- padding="max_length",
306
- return_tensors="pt",
307
- )
308
-
309
- input_ids = encoded["input_ids"].to(device)
310
- attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
311
-
312
- self.model.eval()
313
-
314
- with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
315
- embeddings = self.model(
316
- input_ids=input_ids, attention_mask=attention_mask
317
- )["last_hidden_state"]
318
-
319
- embeddings = self.proj_out(embeddings.float())
320
-
321
- embeddings = embeddings * attention_mask.unsqueeze(-1).float()
322
-
323
- return embeddings, attention_mask
324
-
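A usage sketch for the T5 conditioner above. It assumes a CUDA device and that `transformers` can download `t5-base`; the import path is assumed from this file's location:

```python
from stable_audio_tools.models.conditioners import T5Conditioner

conditioner = T5Conditioner(output_dim=768, t5_model_name="t5-base", max_length=128)
embeddings, mask = conditioner(["a dog barking in the distance", "rain on a tin roof"], device="cuda")
print(embeddings.shape, mask.shape)          # torch.Size([2, 128, 768]) torch.Size([2, 128])
```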
325
- class PhonemeConditioner(Conditioner):
326
- """
327
- A conditioner that turns text into phonemes and embeds them using a lookup table
328
- Only works for English text
329
-
330
- Args:
331
- output_dim: the dimension of the output embeddings
332
- max_length: the maximum number of phonemes to embed
333
- project_out: whether to add another linear projection to the output embeddings
334
- """
335
-
336
- def __init__(
337
- self,
338
- output_dim: int,
339
- max_length: int = 1024,
340
- project_out: bool = False,
341
- ):
342
- super().__init__(output_dim, output_dim, project_out=project_out)
343
-
344
- from g2p_en import G2p
345
-
346
- self.max_length = max_length
347
-
348
- self.g2p = G2p()
349
-
350
- # Reserving 0 for padding, 1 for ignored
351
- self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
352
-
353
- def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
354
-
355
- self.phoneme_embedder.to(device)
356
- self.proj_out.to(device)
357
-
358
- batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
359
-
360
- phoneme_ignore = [" ", *string.punctuation]
361
-
362
- # Replace ignored phonemes (spaces and punctuation) with a placeholder token
363
- batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
364
-
365
- # Convert to ids
366
- phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
367
-
368
- #Pad to match longest and make a mask tensor for the padding
369
- longest = max([len(ids) for ids in phoneme_ids])
370
- phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
371
-
372
- phoneme_ids = torch.tensor(phoneme_ids).to(device)
373
-
374
- # Convert to embeddings
375
- phoneme_embeds = self.phoneme_embedder(phoneme_ids)
376
-
377
- phoneme_embeds = self.proj_out(phoneme_embeds)
378
-
379
- return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
380
-
381
- class TokenizerLUTConditioner(Conditioner):
382
- """
383
- A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
384
-
385
- Args:
386
- tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
387
- output_dim: the dimension of the output embeddings
388
- max_length: the maximum length of the text to embed
389
- project_out: whether to add another linear projection to the output embeddings
390
- """
391
-
392
- def __init__(
393
- self,
394
- tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
395
- output_dim: int,
396
- max_length: int = 1024,
397
- project_out: bool = False,
398
- ):
399
- super().__init__(output_dim, output_dim, project_out=project_out)
400
-
401
- from transformers import AutoTokenizer
402
-
403
- # Suppress logging from transformers
404
- previous_level = logging.root.manager.disable
405
- logging.disable(logging.ERROR)
406
- with warnings.catch_warnings():
407
- warnings.simplefilter("ignore")
408
- try:
409
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
410
- finally:
411
- logging.disable(previous_level)
412
-
413
- self.max_length = max_length
414
-
415
- self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
416
-
417
- def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
418
- self.proj_out.to(device)
419
-
420
- encoded = self.tokenizer(
421
- texts,
422
- truncation=True,
423
- max_length=self.max_length,
424
- padding="max_length",
425
- return_tensors="pt",
426
- )
427
-
428
- input_ids = encoded["input_ids"].to(device)
429
- attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
430
-
431
- embeddings = self.token_embedder(input_ids)
432
-
433
- embeddings = self.proj_out(embeddings)
434
-
435
- embeddings = embeddings * attention_mask.unsqueeze(-1).float()
436
-
437
- return embeddings, attention_mask
438
-
439
- class PretransformConditioner(Conditioner):
440
- """
441
- A conditioner that uses a pretransform's encoder for conditioning
442
-
443
- Args:
444
- pretransform: an instantiated pretransform to use for conditioning
445
- output_dim: the dimension of the output embeddings
446
- """
447
- def __init__(self, pretransform: Pretransform, output_dim: int):
448
- super().__init__(pretransform.encoded_channels, output_dim)
449
-
450
- self.pretransform = pretransform
451
-
452
- def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
453
-
454
- self.pretransform.to(device)
455
- self.proj_out.to(device)
456
-
457
- if isinstance(audio, list) or isinstance(audio, tuple):
458
- audio = torch.cat(audio, dim=0)
459
-
460
- # Convert audio to pretransform input channels
461
- audio = set_audio_channels(audio, self.pretransform.io_channels)
462
-
463
- latents = self.pretransform.encode(audio)
464
-
465
- latents = self.proj_out(latents)
466
-
467
- return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
468
-
469
- class MultiConditioner(nn.Module):
470
- """
471
- A module that applies multiple conditioners to an input dictionary based on the keys
472
-
473
- Args:
474
- conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
475
- default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
476
- """
477
- def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
478
- super().__init__()
479
-
480
- self.conditioners = nn.ModuleDict(conditioners)
481
- self.default_keys = default_keys
482
-
483
- def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
484
- output = {}
485
-
486
- for key, conditioner in self.conditioners.items():
487
- condition_key = key
488
-
489
- conditioner_inputs = []
490
-
491
- for x in batch_metadata:
492
-
493
- if condition_key not in x:
494
- if condition_key in self.default_keys:
495
- condition_key = self.default_keys[condition_key]
496
- else:
497
- raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
498
-
499
- #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list
500
- if isinstance(x[condition_key], (list, tuple)) and len(x[condition_key]) == 1:
501
- conditioner_input = x[condition_key][0]
502
-
503
- else:
504
- conditioner_input = x[condition_key]
505
-
506
- conditioner_inputs.append(conditioner_input)
507
-
508
- output[key] = conditioner(conditioner_inputs, device)
509
-
510
- return output
511
-
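A sketch of how `default_keys` resolves a missing metadata key in the forward pass above. The key names are illustrative, and the printed shape assumes `NumberEmbedder` maps a batch of scalars to `[batch, features]`:

```python
from stable_audio_tools.models.conditioners import MultiConditioner, NumberConditioner

conditioner = MultiConditioner(
    conditioners={"seconds_total": NumberConditioner(output_dim=768, min_val=0, max_val=512)},
    default_keys={"seconds_total": "duration"},   # fall back to the "duration" metadata field
)

batch_metadata = [{"duration": 30.0}, {"duration": 95.0}]
outputs = conditioner(batch_metadata, device="cpu")
print(outputs["seconds_total"][0].shape)          # expected: torch.Size([2, 1, 768])
```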
512
- def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
513
- """
514
- Create a MultiConditioner from a conditioning config dictionary
515
-
516
- Args:
517
- config: the conditioning config dictionary
519
- """
520
- conditioners = {}
521
- cond_dim = config["cond_dim"]
522
-
523
- default_keys = config.get("default_keys", {})
524
-
525
- for conditioner_info in config["configs"]:
526
- id = conditioner_info["id"]
527
-
528
- conditioner_type = conditioner_info["type"]
529
-
530
- conditioner_config = {"output_dim": cond_dim}
531
-
532
- conditioner_config.update(conditioner_info["config"])
533
-
534
- if conditioner_type == "t5":
535
- conditioners[id] = T5Conditioner(**conditioner_config)
536
- elif conditioner_type == "clap_text":
537
- conditioners[id] = CLAPTextConditioner(**conditioner_config)
538
- elif conditioner_type == "clap_audio":
539
- conditioners[id] = CLAPAudioConditioner(**conditioner_config)
540
- elif conditioner_type == "int":
541
- conditioners[id] = IntConditioner(**conditioner_config)
542
- elif conditioner_type == "number":
543
- conditioners[id] = NumberConditioner(**conditioner_config)
544
- elif conditioner_type == "phoneme":
545
- conditioners[id] = PhonemeConditioner(**conditioner_config)
546
- elif conditioner_type == "lut":
547
- conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
548
- elif conditioner_type == "pretransform":
549
- sample_rate = conditioner_config.pop("sample_rate", None)
550
- assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
551
-
552
- pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
553
-
554
- if conditioner_config.get("pretransform_ckpt_path", None) is not None:
555
- pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
556
-
557
- conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
558
- else:
559
- raise ValueError(f"Unknown conditioner type: {conditioner_type}")
560
-
561
- return MultiConditioner(conditioners, default_keys=default_keys)
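The factory above is driven by a JSON-style conditioning config. A minimal sketch of such a config and the corresponding call; the ids are illustrative, only the `t5` and `number` branches are exercised, and instantiating the T5 branch downloads `t5-base` on first use:

```python
from stable_audio_tools.models.conditioners import create_multi_conditioner_from_conditioning_config

conditioning_config = {
    "cond_dim": 768,
    "configs": [
        {"id": "prompt", "type": "t5", "config": {"t5_model_name": "t5-base", "max_length": 128}},
        {"id": "seconds_total", "type": "number", "config": {"min_val": 0, "max_val": 512}},
    ],
    "default_keys": {},
}

conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
```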
stable/build/lib/stable_audio_tools/models/diffusion.py DELETED
@@ -1,701 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
- from functools import partial
5
- import numpy as np
6
- import typing as tp
7
-
8
- from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
9
- from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
- from .dit import DiffusionTransformer
11
- from .factory import create_pretransform_from_config
12
- from .pretransforms import Pretransform
13
- from ..inference.generation import generate_diffusion_cond
14
-
15
- from .adp import UNetCFG1d, UNet1d
16
-
17
- from time import time
18
-
19
- class Profiler:
20
-
21
- def __init__(self):
22
- self.ticks = [[time(), None]]
23
-
24
- def tick(self, msg):
25
- self.ticks.append([time(), msg])
26
-
27
- def __repr__(self):
28
- rep = 80 * "=" + "\n"
29
- for i in range(1, len(self.ticks)):
30
- msg = self.ticks[i][1]
31
- ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
32
- rep += msg + f": {ellapsed*1000:.2f}ms\n"
33
- rep += 80 * "=" + "\n\n\n"
34
- return rep
35
-
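A quick usage sketch for the Profiler helper above (import path assumed from this file's location):

```python
from time import sleep
from stable_audio_tools.models.diffusion import Profiler

p = Profiler()
sleep(0.01)
p.tick("setup")
sleep(0.02)
p.tick("forward pass")
print(p)   # prints each labelled tick with the milliseconds elapsed since the previous one
```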
36
- class DiffusionModel(nn.Module):
37
- def __init__(self, *args, **kwargs):
38
- super().__init__(*args, **kwargs)
39
-
40
- def forward(self, x, t, **kwargs):
41
- raise NotImplementedError()
42
-
43
- class DiffusionModelWrapper(nn.Module):
44
- def __init__(
45
- self,
46
- model: DiffusionModel,
47
- io_channels,
48
- sample_size,
49
- sample_rate,
50
- min_input_length,
51
- pretransform: tp.Optional[Pretransform] = None,
52
- ):
53
- super().__init__()
54
- self.io_channels = io_channels
55
- self.sample_size = sample_size
56
- self.sample_rate = sample_rate
57
- self.min_input_length = min_input_length
58
-
59
- self.model = model
60
-
61
- if pretransform is not None:
62
- self.pretransform = pretransform
63
- else:
64
- self.pretransform = None
65
-
66
- def forward(self, x, t, **kwargs):
67
- return self.model(x, t, **kwargs)
68
-
69
- class ConditionedDiffusionModel(nn.Module):
70
- def __init__(self,
71
- *args,
72
- supports_cross_attention: bool = False,
73
- supports_input_concat: bool = False,
74
- supports_global_cond: bool = False,
75
- supports_prepend_cond: bool = False,
76
- **kwargs):
77
- super().__init__(*args, **kwargs)
78
- self.supports_cross_attention = supports_cross_attention
79
- self.supports_input_concat = supports_input_concat
80
- self.supports_global_cond = supports_global_cond
81
- self.supports_prepend_cond = supports_prepend_cond
82
-
83
- def forward(self,
84
- x: torch.Tensor,
85
- t: torch.Tensor,
86
- cross_attn_cond: torch.Tensor = None,
87
- cross_attn_mask: torch.Tensor = None,
88
- input_concat_cond: torch.Tensor = None,
89
- global_embed: torch.Tensor = None,
90
- prepend_cond: torch.Tensor = None,
91
- prepend_cond_mask: torch.Tensor = None,
92
- cfg_scale: float = 1.0,
93
- cfg_dropout_prob: float = 0.0,
94
- batch_cfg: bool = False,
95
- rescale_cfg: bool = False,
96
- **kwargs):
97
- raise NotImplementedError()
98
-
99
- class ConditionedDiffusionModelWrapper(nn.Module):
100
- """
101
- A diffusion model that takes in conditioning
102
- """
103
- def __init__(
104
- self,
105
- model: ConditionedDiffusionModel,
106
- conditioner: MultiConditioner,
107
- io_channels,
108
- sample_rate,
109
- min_input_length: int,
110
- diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
111
- pretransform: tp.Optional[Pretransform] = None,
112
- cross_attn_cond_ids: tp.List[str] = [],
113
- global_cond_ids: tp.List[str] = [],
114
- input_concat_ids: tp.List[str] = [],
115
- prepend_cond_ids: tp.List[str] = [],
116
- ):
117
- super().__init__()
118
-
119
- self.model = model
120
- self.conditioner = conditioner
121
- self.io_channels = io_channels
122
- self.sample_rate = sample_rate
123
- self.diffusion_objective = diffusion_objective
124
- self.pretransform = pretransform
125
- self.cross_attn_cond_ids = cross_attn_cond_ids
126
- self.global_cond_ids = global_cond_ids
127
- self.input_concat_ids = input_concat_ids
128
- self.prepend_cond_ids = prepend_cond_ids
129
- self.min_input_length = min_input_length
130
-
131
- def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
132
- cross_attention_input = None
133
- cross_attention_masks = None
134
- global_cond = None
135
- input_concat_cond = None
136
- prepend_cond = None
137
- prepend_cond_mask = None
138
-
139
- if len(self.cross_attn_cond_ids) > 0:
140
- # Concatenate all cross-attention inputs over the sequence dimension
141
- # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
142
- cross_attention_input = []
143
- cross_attention_masks = []
144
-
145
- for key in self.cross_attn_cond_ids:
146
- cross_attn_in, cross_attn_mask = conditioning_tensors[key]
147
-
148
- # Add sequence dimension if it's not there
149
- if len(cross_attn_in.shape) == 2:
150
- cross_attn_in = cross_attn_in.unsqueeze(1)
151
- cross_attn_mask = cross_attn_mask.unsqueeze(1)
152
-
153
- cross_attention_input.append(cross_attn_in)
154
- cross_attention_masks.append(cross_attn_mask)
155
-
156
- cross_attention_input = torch.cat(cross_attention_input, dim=1)
157
- cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
158
-
159
- if len(self.global_cond_ids) > 0:
160
- # Concatenate all global conditioning inputs over the channel dimension
161
- # Assumes that the global conditioning inputs are of shape (batch, channels)
162
- global_conds = []
163
- for key in self.global_cond_ids:
164
- global_cond_input = conditioning_tensors[key][0]
165
-
166
- global_conds.append(global_cond_input)
167
-
168
- # Concatenate over the channel dimension
169
- global_cond = torch.cat(global_conds, dim=-1)
170
-
171
- if len(global_cond.shape) == 3:
172
- global_cond = global_cond.squeeze(1)
173
-
174
- if len(self.input_concat_ids) > 0:
175
- # Concatenate all input concat conditioning inputs over the channel dimension
176
- # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
177
- input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
178
-
179
- if len(self.prepend_cond_ids) > 0:
180
- # Concatenate all prepend conditioning inputs over the sequence dimension
181
- # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
182
- prepend_conds = []
183
- prepend_cond_masks = []
184
-
185
- for key in self.prepend_cond_ids:
186
- prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
187
- prepend_conds.append(prepend_cond_input)
188
- prepend_cond_masks.append(prepend_cond_mask)
189
-
190
- prepend_cond = torch.cat(prepend_conds, dim=1)
191
- prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
192
-
193
- if negative:
194
- return {
195
- "negative_cross_attn_cond": cross_attention_input,
196
- "negative_cross_attn_mask": cross_attention_masks,
197
- "negative_global_cond": global_cond,
198
- "negative_input_concat_cond": input_concat_cond
199
- }
200
- else:
201
- return {
202
- "cross_attn_cond": cross_attention_input,
203
- "cross_attn_mask": cross_attention_masks,
204
- "global_cond": global_cond,
205
- "input_concat_cond": input_concat_cond,
206
- "prepend_cond": prepend_cond,
207
- "prepend_cond_mask": prepend_cond_mask
208
- }
209
-
210
- def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
211
- return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
212
-
213
- def generate(self, *args, **kwargs):
214
- return generate_diffusion_cond(self, *args, **kwargs)
215
-
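A sketch of the tensor-shape conventions `get_conditioning_inputs` expects, using dummy tensors in place of real conditioner outputs; the key names and shapes are illustrative assumptions:

```python
import torch

batch, seq, channels = 2, 128, 768

# Each conditioner returns (embedding, mask); the wrapper concatenates them per role.
conditioning_tensors = {
    "prompt":        (torch.randn(batch, seq, channels), torch.ones(batch, seq)),  # cross-attention input
    "seconds_total": (torch.randn(batch, 1, channels),   torch.ones(batch, 1)),    # global conditioning
}

# cross_attn_cond_ids=["prompt"]: concatenated over dim=1 (the sequence dimension)
cross_attn_cond = torch.cat([conditioning_tensors["prompt"][0]], dim=1)
# global_cond_ids=["seconds_total"]: concatenated over dim=-1, then squeezed to [batch, channels]
global_cond = torch.cat([conditioning_tensors["seconds_total"][0]], dim=-1).squeeze(1)
print(cross_attn_cond.shape, global_cond.shape)   # torch.Size([2, 128, 768]) torch.Size([2, 768])
```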
216
- class UNetCFG1DWrapper(ConditionedDiffusionModel):
217
- def __init__(
218
- self,
219
- *args,
220
- **kwargs
221
- ):
222
- super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
223
-
224
- self.model = UNetCFG1d(*args, **kwargs)
225
-
226
- with torch.no_grad():
227
- for param in self.model.parameters():
228
- param *= 0.5
229
-
230
- def forward(self,
231
- x,
232
- t,
233
- cross_attn_cond=None,
234
- cross_attn_mask=None,
235
- input_concat_cond=None,
236
- global_cond=None,
237
- cfg_scale=1.0,
238
- cfg_dropout_prob: float = 0.0,
239
- batch_cfg: bool = False,
240
- rescale_cfg: bool = False,
241
- negative_cross_attn_cond=None,
242
- negative_cross_attn_mask=None,
243
- negative_global_cond=None,
244
- negative_input_concat_cond=None,
245
- prepend_cond=None,
246
- prepend_cond_mask=None,
247
- **kwargs):
248
- p = Profiler()
249
-
250
- p.tick("start")
251
-
252
- channels_list = None
253
- if input_concat_cond is not None:
254
- channels_list = [input_concat_cond]
255
-
256
- outputs = self.model(
257
- x,
258
- t,
259
- embedding=cross_attn_cond,
260
- embedding_mask=cross_attn_mask,
261
- features=global_cond,
262
- channels_list=channels_list,
263
- embedding_scale=cfg_scale,
264
- embedding_mask_proba=cfg_dropout_prob,
265
- batch_cfg=batch_cfg,
266
- rescale_cfg=rescale_cfg,
267
- negative_embedding=negative_cross_attn_cond,
268
- negative_embedding_mask=negative_cross_attn_mask,
269
- **kwargs)
270
-
271
- p.tick("UNetCFG1D forward")
272
-
273
- #print(f"Profiler: {p}")
274
- return outputs
275
-
276
- class UNet1DCondWrapper(ConditionedDiffusionModel):
277
- def __init__(
278
- self,
279
- *args,
280
- **kwargs
281
- ):
282
- super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
283
-
284
- self.model = UNet1d(*args, **kwargs)
285
-
286
- with torch.no_grad():
287
- for param in self.model.parameters():
288
- param *= 0.5
289
-
290
- def forward(self,
291
- x,
292
- t,
293
- input_concat_cond=None,
294
- global_cond=None,
295
- cross_attn_cond=None,
296
- cross_attn_mask=None,
297
- prepend_cond=None,
298
- prepend_cond_mask=None,
299
- cfg_scale=1.0,
300
- cfg_dropout_prob: float = 0.0,
301
- batch_cfg: bool = False,
302
- rescale_cfg: bool = False,
303
- negative_cross_attn_cond=None,
304
- negative_cross_attn_mask=None,
305
- negative_global_cond=None,
306
- negative_input_concat_cond=None,
307
- **kwargs):
308
-
309
- channels_list = None
310
- if input_concat_cond is not None:
311
-
312
- # Interpolate input_concat_cond to the same length as x
313
- if input_concat_cond.shape[2] != x.shape[2]:
314
- input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
315
-
316
- channels_list = [input_concat_cond]
317
-
318
- outputs = self.model(
319
- x,
320
- t,
321
- features=global_cond,
322
- channels_list=channels_list,
323
- **kwargs)
324
-
325
- return outputs
326
-
327
- class UNet1DUncondWrapper(DiffusionModel):
328
- def __init__(
329
- self,
330
- in_channels,
331
- *args,
332
- **kwargs
333
- ):
334
- super().__init__()
335
-
336
- self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
337
-
338
- self.io_channels = in_channels
339
-
340
- with torch.no_grad():
341
- for param in self.model.parameters():
342
- param *= 0.5
343
-
344
- def forward(self, x, t, **kwargs):
345
- return self.model(x, t, **kwargs)
346
-
347
- class DAU1DCondWrapper(ConditionedDiffusionModel):
348
- def __init__(
349
- self,
350
- *args,
351
- **kwargs
352
- ):
353
- super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
354
-
355
- self.model = DiffusionAttnUnet1D(*args, **kwargs)
356
-
357
- with torch.no_grad():
358
- for param in self.model.parameters():
359
- param *= 0.5
360
-
361
- def forward(self,
362
- x,
363
- t,
364
- input_concat_cond=None,
365
- cross_attn_cond=None,
366
- cross_attn_mask=None,
367
- global_cond=None,
368
- cfg_scale=1.0,
369
- cfg_dropout_prob: float = 0.0,
370
- batch_cfg: bool = False,
371
- rescale_cfg: bool = False,
372
- negative_cross_attn_cond=None,
373
- negative_cross_attn_mask=None,
374
- negative_global_cond=None,
375
- negative_input_concat_cond=None,
376
- prepend_cond=None,
377
- **kwargs):
378
-
379
- return self.model(x, t, cond = input_concat_cond)
380
-
381
- class DiffusionAttnUnet1D(nn.Module):
382
- def __init__(
383
- self,
384
- io_channels = 2,
385
- depth=14,
386
- n_attn_layers = 6,
387
- channels = [128, 128, 256, 256] + [512] * 10,
388
- cond_dim = 0,
389
- cond_noise_aug = False,
390
- kernel_size = 5,
391
- learned_resample = False,
392
- strides = [2] * 13,
393
- conv_bias = True,
394
- use_snake = False
395
- ):
396
- super().__init__()
397
-
398
- self.cond_noise_aug = cond_noise_aug
399
-
400
- self.io_channels = io_channels
401
-
402
- if self.cond_noise_aug:
403
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
404
-
405
- self.timestep_embed = FourierFeatures(1, 16)
406
-
407
- attn_layer = depth - n_attn_layers
408
-
409
- strides = [1] + strides
410
-
411
- block = nn.Identity()
412
-
413
- conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
414
-
415
- for i in range(depth, 0, -1):
416
- c = channels[i - 1]
417
- stride = strides[i-1]
418
- if stride > 2 and not learned_resample:
419
- raise ValueError("Must have stride 2 without learned resampling")
420
-
421
- if i > 1:
422
- c_prev = channels[i - 2]
423
- add_attn = i >= attn_layer and n_attn_layers > 0
424
- block = SkipBlock(
425
- Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
426
- conv_block(c_prev, c, c),
427
- SelfAttention1d(
428
- c, c // 32) if add_attn else nn.Identity(),
429
- conv_block(c, c, c),
430
- SelfAttention1d(
431
- c, c // 32) if add_attn else nn.Identity(),
432
- conv_block(c, c, c),
433
- SelfAttention1d(
434
- c, c // 32) if add_attn else nn.Identity(),
435
- block,
436
- conv_block(c * 2 if i != depth else c, c, c),
437
- SelfAttention1d(
438
- c, c // 32) if add_attn else nn.Identity(),
439
- conv_block(c, c, c),
440
- SelfAttention1d(
441
- c, c // 32) if add_attn else nn.Identity(),
442
- conv_block(c, c, c_prev),
443
- SelfAttention1d(c_prev, c_prev //
444
- 32) if add_attn else nn.Identity(),
445
- Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
446
- )
447
- else:
448
- cond_embed_dim = 16 if not self.cond_noise_aug else 32
449
- block = nn.Sequential(
450
- conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
451
- conv_block(c, c, c),
452
- conv_block(c, c, c),
453
- block,
454
- conv_block(c * 2, c, c),
455
- conv_block(c, c, c),
456
- conv_block(c, c, io_channels, is_last=True),
457
- )
458
- self.net = block
459
-
460
- with torch.no_grad():
461
- for param in self.net.parameters():
462
- param *= 0.5
463
-
464
- def forward(self, x, t, cond=None, cond_aug_scale=None):
465
-
466
- timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
467
-
468
- inputs = [x, timestep_embed]
469
-
470
- if cond is not None:
471
- if cond.shape[2] != x.shape[2]:
472
- cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
473
-
474
- if self.cond_noise_aug:
475
- # Get a random number between 0 and 1, uniformly sampled
476
- if cond_aug_scale is None:
477
- aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
478
- else:
479
- aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
480
-
481
- # Add noise to the conditioning signal
482
- cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
483
-
484
- # Get embedding for the noise conditioning level, reusing timestep_embed
485
- aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
486
-
487
- inputs.append(aug_level_embed)
488
-
489
- inputs.append(cond)
490
-
491
- outputs = self.net(torch.cat(inputs, dim=1))
492
-
493
- return outputs
494
-
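For orientation, the block below is a minimal smoke test of the attention U-Net defined above. It is a sketch only: the depth, channel, and stride values are illustrative, and it assumes it runs inside this module so that `DiffusionAttnUnet1D` and its building blocks are available.

```python
# Illustrative smoke test for DiffusionAttnUnet1D (values are assumptions, not repo configs).
import torch

unet = DiffusionAttnUnet1D(
    io_channels=2,
    depth=4,                      # small depth for a quick test
    n_attn_layers=1,              # self-attention only on the deepest levels
    channels=[64, 64, 128, 128],  # one entry per level
    strides=[2, 2, 2],            # one stride per downsampling step
)

x = torch.randn(1, 2, 2**14)      # (batch, io_channels, samples), divisible by the total stride
t = torch.rand(1)                 # one diffusion timestep per batch item
out = unet(x, t)                  # output has the same shape as x
```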
495
- class DiTWrapper(ConditionedDiffusionModel):
496
- def __init__(
497
- self,
498
- *args,
499
- **kwargs
500
- ):
501
- super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
502
-
503
- self.model = DiffusionTransformer(*args, **kwargs)
504
-
505
- with torch.no_grad():
506
- for param in self.model.parameters():
507
- param *= 0.5
508
-
509
- def forward(self,
510
- x,
511
- t,
512
- cross_attn_cond=None,
513
- cross_attn_mask=None,
514
- negative_cross_attn_cond=None,
515
- negative_cross_attn_mask=None,
516
- input_concat_cond=None,
517
- negative_input_concat_cond=None,
518
- global_cond=None,
519
- negative_global_cond=None,
520
- prepend_cond=None,
521
- prepend_cond_mask=None,
522
- cfg_scale=1.0,
523
- cfg_dropout_prob: float = 0.0,
524
- batch_cfg: bool = True,
525
- rescale_cfg: bool = False,
526
- scale_phi: float = 0.0,
527
- **kwargs):
528
-
529
- assert batch_cfg, "batch_cfg must be True for DiTWrapper"
530
- #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
531
-
532
- return self.model(
533
- x,
534
- t,
535
- cross_attn_cond=cross_attn_cond,
536
- cross_attn_cond_mask=cross_attn_mask,
537
- negative_cross_attn_cond=negative_cross_attn_cond,
538
- negative_cross_attn_mask=negative_cross_attn_mask,
539
- input_concat_cond=input_concat_cond,
540
- prepend_cond=prepend_cond,
541
- prepend_cond_mask=prepend_cond_mask,
542
- cfg_scale=cfg_scale,
543
- cfg_dropout_prob=cfg_dropout_prob,
544
- scale_phi=scale_phi,
545
- global_embed=global_cond,
546
- **kwargs)
547
-
548
- class DiTUncondWrapper(DiffusionModel):
549
- def __init__(
550
- self,
551
- in_channels,
552
- *args,
553
- **kwargs
554
- ):
555
- super().__init__()
556
-
557
- self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
558
-
559
- self.io_channels = in_channels
560
-
561
- with torch.no_grad():
562
- for param in self.model.parameters():
563
- param *= 0.5
564
-
565
- def forward(self, x, t, **kwargs):
566
- return self.model(x, t, **kwargs)
567
-
568
- def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
569
- diffusion_uncond_config = config["model"]
570
-
571
- model_type = diffusion_uncond_config.get('type', None)
572
-
573
- diffusion_config = diffusion_uncond_config.get('config', {})
574
-
575
- assert model_type is not None, "Must specify model type in config"
576
-
577
- pretransform = diffusion_uncond_config.get("pretransform", None)
578
-
579
- sample_size = config.get("sample_size", None)
580
- assert sample_size is not None, "Must specify sample size in config"
581
-
582
- sample_rate = config.get("sample_rate", None)
583
- assert sample_rate is not None, "Must specify sample rate in config"
584
-
585
- if pretransform is not None:
586
- pretransform = create_pretransform_from_config(pretransform, sample_rate)
587
- min_input_length = pretransform.downsampling_ratio
588
- else:
589
- min_input_length = 1
590
-
591
- if model_type == 'DAU1d':
592
-
593
- model = DiffusionAttnUnet1D(
594
- **diffusion_config
595
- )
596
-
597
- elif model_type == "adp_uncond_1d":
598
-
599
- model = UNet1DUncondWrapper(
600
- **diffusion_config
601
- )
602
-
603
- elif model_type == "dit":
604
- model = DiTUncondWrapper(
605
- **diffusion_config
606
- )
607
-
608
- else:
609
- raise NotImplementedError(f'Unknown model type: {model_type}')
610
-
611
- return DiffusionModelWrapper(model,
612
- io_channels=model.io_channels,
613
- sample_size=sample_size,
614
- sample_rate=sample_rate,
615
- pretransform=pretransform,
616
- min_input_length=min_input_length)
617
-
618
- def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
619
-
620
- model_config = config["model"]
621
-
622
- model_type = config["model_type"]
623
-
624
- diffusion_config = model_config.get('diffusion', None)
625
- assert diffusion_config is not None, "Must specify diffusion config"
626
-
627
- diffusion_model_type = diffusion_config.get('type', None)
628
- assert diffusion_model_type is not None, "Must specify diffusion model type"
629
-
630
- diffusion_model_config = diffusion_config.get('config', None)
631
- assert diffusion_model_config is not None, "Must specify diffusion model config"
632
-
633
- if diffusion_model_type == 'adp_cfg_1d':
634
- diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
635
- elif diffusion_model_type == 'adp_1d':
636
- diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
637
- elif diffusion_model_type == 'dit':
638
- diffusion_model = DiTWrapper(**diffusion_model_config)
639
-
640
- io_channels = model_config.get('io_channels', None)
641
- assert io_channels is not None, "Must specify io_channels in model config"
642
-
643
- sample_rate = config.get('sample_rate', None)
644
- assert sample_rate is not None, "Must specify sample_rate in config"
645
-
646
- diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
647
-
648
- conditioning_config = model_config.get('conditioning', None)
649
-
650
- conditioner = None
651
- if conditioning_config is not None:
652
- conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
653
-
654
- cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
655
- global_cond_ids = diffusion_config.get('global_cond_ids', [])
656
- input_concat_ids = diffusion_config.get('input_concat_ids', [])
657
- prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
658
-
659
- pretransform = model_config.get("pretransform", None)
660
-
661
- if pretransform is not None:
662
- pretransform = create_pretransform_from_config(pretransform, sample_rate)
663
- min_input_length = pretransform.downsampling_ratio
664
- else:
665
- min_input_length = 1
666
-
667
- if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
668
- min_input_length *= np.prod(diffusion_model_config["factors"])
669
- elif diffusion_model_type == "dit":
670
- min_input_length *= diffusion_model.model.patch_size
671
-
672
- # Get the proper wrapper class
673
-
674
- extra_kwargs = {}
675
-
676
- if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
677
- wrapper_fn = ConditionedDiffusionModelWrapper
678
-
679
- extra_kwargs["diffusion_objective"] = diffusion_objective
680
-
681
- elif model_type == "diffusion_prior":
682
- prior_type = model_config.get("prior_type", None)
683
- assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
684
-
685
- if prior_type == "mono_stereo":
686
- from .diffusion_prior import MonoToStereoDiffusionPrior
687
- wrapper_fn = MonoToStereoDiffusionPrior
688
-
689
- return wrapper_fn(
690
- diffusion_model,
691
- conditioner,
692
- min_input_length=min_input_length,
693
- sample_rate=sample_rate,
694
- cross_attn_cond_ids=cross_attention_ids,
695
- global_cond_ids=global_cond_ids,
696
- input_concat_ids=input_concat_ids,
697
- prepend_cond_ids=prepend_cond_ids,
698
- pretransform=pretransform,
699
- io_channels=io_channels,
700
- **extra_kwargs
701
- )
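The two factory functions above map JSON-style configs onto the wrapper classes in this file. As a rough sketch of the unconditional path, a config along these lines would be accepted by `create_diffusion_uncond_from_config`; every value below is illustrative and not taken from the repo's actual config files.

```python
# Illustrative config for create_diffusion_uncond_from_config (values are assumptions).
config = {
    "sample_size": 65536,
    "sample_rate": 44100,
    "model": {
        "type": "dit",
        "config": {
            "in_channels": 64,
            "embed_dim": 768,
            "depth": 12,
            "num_heads": 8,
        },
        # "pretransform": {...}  # optional; its downsampling ratio becomes min_input_length
    },
}

wrapper = create_diffusion_uncond_from_config(config)  # -> DiffusionModelWrapper
```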
stable/build/lib/stable_audio_tools/models/diffusion_prior.py DELETED
@@ -1,79 +0,0 @@
1
- from enum import Enum
2
- import typing as tp
3
-
4
- from .diffusion import ConditionedDiffusionModelWrapper
5
- from ..inference.generation import generate_diffusion_cond
6
- from ..inference.utils import prepare_audio
7
-
8
- import torch
9
- from torch.nn import functional as F
10
- from torchaudio import transforms as T
11
-
12
- # Define prior types enum
13
- class PriorType(Enum):
14
- MonoToStereo = 1
15
-
16
- class DiffusionPrior(ConditionedDiffusionModelWrapper):
17
- def __init__(self, *args, prior_type: PriorType=None, **kwargs):
18
- super().__init__(*args, **kwargs)
19
- self.prior_type = prior_type
20
-
21
- class MonoToStereoDiffusionPrior(DiffusionPrior):
22
- def __init__(self, *args, **kwargs):
23
- super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)
24
-
25
- def stereoize(
26
- self,
27
- audio: torch.Tensor, # (batch, channels, time)
28
- in_sr: int,
29
- steps: int,
30
- sampler_kwargs: dict = {},
31
- ):
32
- """
33
- Generate stereo audio from mono audio using a pre-trained diffusion prior
34
-
35
- Args:
36
- audio: The mono audio to convert to stereo
37
- in_sr: The sample rate of the input audio
38
- steps: The number of diffusion steps to run
39
- sampler_kwargs: Keyword arguments to pass to the diffusion sampler
40
- """
41
-
42
- device = audio.device
43
-
44
- sample_rate = self.sample_rate
45
-
46
- # Resample input audio if necessary
47
- if in_sr != sample_rate:
48
- resample_tf = T.Resample(in_sr, sample_rate).to(audio.device)
49
- audio = resample_tf(audio)
50
-
51
- audio_length = audio.shape[-1]
52
-
53
- # Pad input audio to be compatible with the model
54
- min_length = self.min_input_length
55
- padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
56
-
57
- # Pad input audio to be compatible with the model
58
- if padded_input_length > audio_length:
59
- audio = F.pad(audio, (0, padded_input_length - audio_length))
60
-
61
- # Make audio mono, duplicate to stereo
62
- dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)
63
-
64
- if self.pretransform is not None:
65
- dual_mono = self.pretransform.encode(dual_mono)
66
-
67
- conditioning = {"source": [dual_mono]}
68
-
69
- stereo_audio = generate_diffusion_cond(
70
- self,
71
- conditioning_tensors=conditioning,
72
- steps=steps,
73
- sample_size=padded_input_length,
74
- sample_rate=sample_rate,
75
- device=device,
76
- **sampler_kwargs,
77
- )
78
-
79
- return stereo_audio
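A hypothetical usage sketch for the prior above: it assumes a trained `MonoToStereoDiffusionPrior` has already been instantiated and loaded with weights as `prior`, and the input file name is a placeholder.

```python
# Hypothetical usage: `prior` is a trained MonoToStereoDiffusionPrior with weights loaded.
import torch
import torchaudio

audio, in_sr = torchaudio.load("input_mono.wav")          # (channels, samples)
audio = audio.mean(0, keepdim=True).unsqueeze(0)          # -> (1, 1, samples)
audio = audio.to(next(prior.parameters()).device)

with torch.no_grad():
    stereo = prior.stereoize(audio, in_sr, steps=100)     # (1, 2, samples), possibly padded

torchaudio.save("output_stereo.wav", stereo[0].cpu(), prior.sample_rate)
```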
stable/build/lib/stable_audio_tools/models/discriminators.py DELETED
@@ -1,546 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
- from functools import reduce
6
- import typing as tp
7
- from einops import rearrange
8
- from audiotools import AudioSignal, STFTParams
9
- from dac.model.discriminator import WNConv1d, WNConv2d
10
-
11
- def get_hinge_losses(score_real, score_fake):
12
- gen_loss = -score_fake.mean()
13
- dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean()
14
- return dis_loss, gen_loss
15
-
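For reference, a quick numeric check of the hinge losses defined above, on arbitrary made-up scores:

```python
# Sanity check of get_hinge_losses on hand-picked values (illustrative only).
import torch

score_real = torch.tensor([1.5, 0.2])    # discriminator scores on real audio
score_fake = torch.tensor([-1.2, 0.4])   # discriminator scores on generated audio

dis_loss, gen_loss = get_hinge_losses(score_real, score_fake)
# dis_loss = mean(relu(1 - real)) + mean(relu(1 + fake)) = 0.4 + 0.7 = 1.1
# gen_loss = -mean(fake) = 0.4
```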
16
- class EncodecDiscriminator(nn.Module):
17
-
18
- def __init__(self, *args, **kwargs):
19
- super().__init__()
20
-
21
- from encodec.msstftd import MultiScaleSTFTDiscriminator
22
-
23
- self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs)
24
-
25
- def forward(self, x):
26
- logits, features = self.discriminators(x)
27
- return logits, features
28
-
29
- def loss(self, x, y):
30
- feature_matching_distance = 0.
31
- logits_true, feature_true = self.forward(x)
32
- logits_fake, feature_fake = self.forward(y)
33
-
34
- dis_loss = torch.tensor(0.)
35
- adv_loss = torch.tensor(0.)
36
-
37
- for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)):
38
-
39
- feature_matching_distance = feature_matching_distance + sum(
40
- map(
41
- lambda x, y: abs(x - y).mean(),
42
- scale_true,
43
- scale_fake,
44
- )) / len(scale_true)
45
-
46
- _dis, _adv = get_hinge_losses(
47
- logits_true[i],
48
- logits_fake[i],
49
- )
50
-
51
- dis_loss = dis_loss + _dis
52
- adv_loss = adv_loss + _adv
53
-
54
- return dis_loss, adv_loss, feature_matching_distance
55
-
56
- # Discriminators from oobleck
57
-
58
- IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]]
59
-
60
- TensorDict = tp.Dict[str, torch.Tensor]
61
-
62
- class SharedDiscriminatorConvNet(nn.Module):
63
-
64
- def __init__(
65
- self,
66
- in_size: int,
67
- convolution: tp.Union[nn.Conv1d, nn.Conv2d],
68
- out_size: int = 1,
69
- capacity: int = 32,
70
- n_layers: int = 4,
71
- kernel_size: int = 15,
72
- stride: int = 4,
73
- activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(),
74
- normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm,
75
- ) -> None:
76
- super().__init__()
77
- channels = [in_size]
78
- channels += list(capacity * 2**np.arange(n_layers))
79
-
80
- if isinstance(stride, int):
81
- stride = n_layers * [stride]
82
-
83
- net = []
84
- for i in range(n_layers):
85
- if isinstance(kernel_size, int):
86
- pad = kernel_size // 2
87
- s = stride[i]
88
- else:
89
- pad = kernel_size[0] // 2
90
- s = (stride[i], 1)
91
-
92
- net.append(
93
- normalization(
94
- convolution(
95
- channels[i],
96
- channels[i + 1],
97
- kernel_size,
98
- stride=s,
99
- padding=pad,
100
- )))
101
- net.append(activation())
102
-
103
- net.append(convolution(channels[-1], out_size, 1))
104
-
105
- self.net = nn.ModuleList(net)
106
-
107
- def forward(self, x) -> IndividualDiscriminatorOut:
108
- features = []
109
- for layer in self.net:
110
- x = layer(x)
111
- if isinstance(layer, nn.modules.conv._ConvNd):
112
- features.append(x)
113
- score = x.reshape(x.shape[0], -1).mean(-1)
114
- return score, features
115
-
116
-
117
- class MultiScaleDiscriminator(nn.Module):
118
-
119
- def __init__(self,
120
- in_channels: int,
121
- n_scales: int,
122
- **conv_kwargs) -> None:
123
- super().__init__()
124
- layers = []
125
- for _ in range(n_scales):
126
- layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs))
127
- self.layers = nn.ModuleList(layers)
128
-
129
- def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
130
- score = 0
131
- features = []
132
- for layer in self.layers:
133
- s, f = layer(x)
134
- score = score + s
135
- features.extend(f)
136
- x = nn.functional.avg_pool1d(x, 2)
137
- return score, features
138
-
139
- class MultiPeriodDiscriminator(nn.Module):
140
-
141
- def __init__(self,
142
- in_channels: int,
143
- periods: tp.Sequence[int],
144
- **conv_kwargs) -> None:
145
- super().__init__()
146
- layers = []
147
- self.periods = periods
148
-
149
- for _ in periods:
150
- layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs))
151
-
152
- self.layers = nn.ModuleList(layers)
153
-
154
- def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
155
- score = 0
156
- features = []
157
- for layer, n in zip(self.layers, self.periods):
158
- s, f = layer(self.fold(x, n))
159
- score = score + s
160
- features.extend(f)
161
- return score, features
162
-
163
- def fold(self, x: torch.Tensor, n: int) -> torch.Tensor:
164
- pad = (n - (x.shape[-1] % n)) % n
165
- x = nn.functional.pad(x, (0, pad))
166
- return x.reshape(*x.shape[:2], -1, n)
167
-
168
-
169
- class MultiDiscriminator(nn.Module):
170
- """
171
- Individual discriminators should take a single tensor as input (NxB C T) and
172
- return a tuple composed of a score tensor (NxB) and a Sequence of Features
173
- Sequence[NxB C' T'].
174
- """
175
-
176
- def __init__(self, discriminator_list: tp.Sequence[nn.Module],
177
- keys: tp.Sequence[str]) -> None:
178
- super().__init__()
179
- self.discriminators = nn.ModuleList(discriminator_list)
180
- self.keys = keys
181
-
182
- def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict:
183
- features = features.chunk(len(self.keys), 0)
184
- return {k: features[i] for i, k in enumerate(self.keys)}
185
-
186
- @staticmethod
187
- def concat_dicts(dict_a, dict_b):
188
- out_dict = {}
189
- keys = set(list(dict_a.keys()) + list(dict_b.keys()))
190
- for k in keys:
191
- out_dict[k] = []
192
- if k in dict_a:
193
- if isinstance(dict_a[k], list):
194
- out_dict[k].extend(dict_a[k])
195
- else:
196
- out_dict[k].append(dict_a[k])
197
- if k in dict_b:
198
- if isinstance(dict_b[k], list):
199
- out_dict[k].extend(dict_b[k])
200
- else:
201
- out_dict[k].append(dict_b[k])
202
- return out_dict
203
-
204
- @staticmethod
205
- def sum_dicts(dict_a, dict_b):
206
- out_dict = {}
207
- keys = set(list(dict_a.keys()) + list(dict_b.keys()))
208
- for k in keys:
209
- out_dict[k] = 0.
210
- if k in dict_a:
211
- out_dict[k] = out_dict[k] + dict_a[k]
212
- if k in dict_b:
213
- out_dict[k] = out_dict[k] + dict_b[k]
214
- return out_dict
215
-
216
- def forward(self, inputs: TensorDict) -> TensorDict:
217
- discriminator_input = torch.cat([inputs[k] for k in self.keys], 0)
218
- all_scores = []
219
- all_features = []
220
-
221
- for discriminator in self.discriminators:
222
- score, features = discriminator(discriminator_input)
223
- scores = self.unpack_tensor_to_dict(score)
224
- scores = {f"score_{k}": scores[k] for k in scores.keys()}
225
- all_scores.append(scores)
226
-
227
- features = map(self.unpack_tensor_to_dict, features)
228
- features = reduce(self.concat_dicts, features)
229
- features = {f"features_{k}": features[k] for k in features.keys()}
230
- all_features.append(features)
231
-
232
- all_scores = reduce(self.sum_dicts, all_scores)
233
- all_features = reduce(self.concat_dicts, all_features)
234
-
235
- inputs.update(all_scores)
236
- inputs.update(all_features)
237
-
238
- return inputs
239
-
240
- class OobleckDiscriminator(nn.Module):
241
-
242
- def __init__(
243
- self,
244
- in_channels=1,
245
- ):
246
- super().__init__()
247
-
248
- multi_scale_discriminator = MultiScaleDiscriminator(
249
- in_channels=in_channels,
250
- n_scales=3,
251
- )
252
-
253
- multi_period_discriminator = MultiPeriodDiscriminator(
254
- in_channels=in_channels,
255
- periods=[2, 3, 5, 7, 11]
256
- )
257
-
258
- # multi_resolution_discriminator = MultiScaleSTFTDiscriminator(
259
- # filters=32,
260
- # in_channels = in_channels,
261
- # out_channels = 1,
262
- # n_ffts = [2048, 1024, 512, 256, 128],
263
- # hop_lengths = [512, 256, 128, 64, 32],
264
- # win_lengths = [2048, 1024, 512, 256, 128]
265
- # )
266
-
267
- self.multi_discriminator = MultiDiscriminator(
268
- [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator],
269
- ["reals", "fakes"]
270
- )
271
-
272
- def loss(self, reals, fakes):
273
- inputs = {
274
- "reals": reals,
275
- "fakes": fakes,
276
- }
277
-
278
- inputs = self.multi_discriminator(inputs)
279
-
280
- scores_real = inputs["score_reals"]
281
- scores_fake = inputs["score_fakes"]
282
-
283
- features_real = inputs["features_reals"]
284
- features_fake = inputs["features_fakes"]
285
-
286
- dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake)
287
-
288
- feature_matching_distance = torch.tensor(0.)
289
-
290
- for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)):
291
-
292
- feature_matching_distance = feature_matching_distance + sum(
293
- map(
294
- lambda real, fake: abs(real - fake).mean(),
295
- scale_real,
296
- scale_fake,
297
- )) / len(scale_real)
298
-
299
- return dis_loss, gen_loss, feature_matching_distance
300
-
301
-
302
- ## Discriminators from Descript Audio Codec repo
303
- ## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt
304
- class MPD(nn.Module):
305
- def __init__(self, period, channels=1):
306
- super().__init__()
307
-
308
- self.period = period
309
- self.convs = nn.ModuleList(
310
- [
311
- WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)),
312
- WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
313
- WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
314
- WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
315
- WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
316
- ]
317
- )
318
- self.conv_post = WNConv2d(
319
- 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
320
- )
321
-
322
- def pad_to_period(self, x):
323
- t = x.shape[-1]
324
- x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
325
- return x
326
-
327
- def forward(self, x):
328
- fmap = []
329
-
330
- x = self.pad_to_period(x)
331
- x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
332
-
333
- for layer in self.convs:
334
- x = layer(x)
335
- fmap.append(x)
336
-
337
- x = self.conv_post(x)
338
- fmap.append(x)
339
-
340
- return fmap
341
-
342
-
343
- class MSD(nn.Module):
344
- def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1):
345
- super().__init__()
346
-
347
- self.convs = nn.ModuleList(
348
- [
349
- WNConv1d(channels, 16, 15, 1, padding=7),
350
- WNConv1d(16, 64, 41, 4, groups=4, padding=20),
351
- WNConv1d(64, 256, 41, 4, groups=16, padding=20),
352
- WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
353
- WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
354
- WNConv1d(1024, 1024, 5, 1, padding=2),
355
- ]
356
- )
357
- self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
358
- self.sample_rate = sample_rate
359
- self.rate = rate
360
-
361
- def forward(self, x):
362
- x = AudioSignal(x, self.sample_rate)
363
- x.resample(self.sample_rate // self.rate)
364
- x = x.audio_data
365
-
366
- fmap = []
367
-
368
- for l in self.convs:
369
- x = l(x)
370
- fmap.append(x)
371
- x = self.conv_post(x)
372
- fmap.append(x)
373
-
374
- return fmap
375
-
376
-
377
- BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
378
-
379
-
380
- class MRD(nn.Module):
381
- def __init__(
382
- self,
383
- window_length: int,
384
- hop_factor: float = 0.25,
385
- sample_rate: int = 44100,
386
- bands: list = BANDS,
387
- channels: int = 1
388
- ):
389
- """Complex multi-band spectrogram discriminator.
390
- Parameters
391
- ----------
392
- window_length : int
393
- Window length of STFT.
394
- hop_factor : float, optional
395
- Hop factor of the STFT; hop length is ``hop_factor * window_length``. Defaults to 0.25.
396
- sample_rate : int, optional
397
- Sampling rate of audio in Hz, by default 44100
398
- bands : list, optional
399
- Bands to run discriminator over.
400
- """
401
- super().__init__()
402
-
403
- self.window_length = window_length
404
- self.hop_factor = hop_factor
405
- self.sample_rate = sample_rate
406
- self.stft_params = STFTParams(
407
- window_length=window_length,
408
- hop_length=int(window_length * hop_factor),
409
- match_stride=True,
410
- )
411
-
412
- self.channels = channels
413
-
414
- n_fft = window_length // 2 + 1
415
- bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
416
- self.bands = bands
417
-
418
- ch = 32
419
- convs = lambda: nn.ModuleList(
420
- [
421
- WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
422
- WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
423
- WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
424
- WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
425
- WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
426
- ]
427
- )
428
- self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
429
- self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
430
-
431
- def spectrogram(self, x):
432
- x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
433
- x = torch.view_as_real(x.stft())
434
- x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels)
435
- # Split into bands
436
- x_bands = [x[..., b[0] : b[1]] for b in self.bands]
437
- return x_bands
438
-
439
- def forward(self, x):
440
- x_bands = self.spectrogram(x)
441
- fmap = []
442
-
443
- x = []
444
- for band, stack in zip(x_bands, self.band_convs):
445
- for layer in stack:
446
- band = layer(band)
447
- fmap.append(band)
448
- x.append(band)
449
-
450
- x = torch.cat(x, dim=-1)
451
- x = self.conv_post(x)
452
- fmap.append(x)
453
-
454
- return fmap
455
-
456
-
457
- class DACDiscriminator(nn.Module):
458
- def __init__(
459
- self,
460
- channels: int = 1,
461
- rates: list = [],
462
- periods: list = [2, 3, 5, 7, 11],
463
- fft_sizes: list = [2048, 1024, 512],
464
- sample_rate: int = 44100,
465
- bands: list = BANDS,
466
- ):
467
- """Discriminator that combines multiple discriminators.
468
-
469
- Parameters
470
- ----------
471
- rates : list, optional
472
- sampling rates (in Hz) to run MSD at, by default []
473
- If empty, MSD is not used.
474
- periods : list, optional
475
- periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
476
- fft_sizes : list, optional
477
- Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
478
- sample_rate : int, optional
479
- Sampling rate of audio in Hz, by default 44100
480
- bands : list, optional
481
- Bands to run MRD at, by default `BANDS`
482
- """
483
- super().__init__()
484
- discs = []
485
- discs += [MPD(p, channels=channels) for p in periods]
486
- discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates]
487
- discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes]
488
- self.discriminators = nn.ModuleList(discs)
489
-
490
- def preprocess(self, y):
491
- # Remove DC offset
492
- y = y - y.mean(dim=-1, keepdims=True)
493
- # Peak normalize the volume of input audio
494
- y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
495
- return y
496
-
497
- def forward(self, x):
498
- x = self.preprocess(x)
499
- fmaps = [d(x) for d in self.discriminators]
500
- return fmaps
501
-
502
- class DACGANLoss(nn.Module):
503
- """
504
- Computes a discriminator loss, given a discriminator on
505
- generated waveforms/spectrograms compared to ground truth
506
- waveforms/spectrograms. Computes the loss for both the
507
- discriminator and the generator in separate functions.
508
- """
509
-
510
- def __init__(self, **discriminator_kwargs):
511
- super().__init__()
512
- self.discriminator = DACDiscriminator(**discriminator_kwargs)
513
-
514
- def forward(self, fake, real):
515
- d_fake = self.discriminator(fake)
516
- d_real = self.discriminator(real)
517
- return d_fake, d_real
518
-
519
- def discriminator_loss(self, fake, real):
520
- d_fake, d_real = self.forward(fake.clone().detach(), real)
521
-
522
- loss_d = 0
523
- for x_fake, x_real in zip(d_fake, d_real):
524
- loss_d += torch.mean(x_fake[-1] ** 2)
525
- loss_d += torch.mean((1 - x_real[-1]) ** 2)
526
- return loss_d
527
-
528
- def generator_loss(self, fake, real):
529
- d_fake, d_real = self.forward(fake, real)
530
-
531
- loss_g = 0
532
- for x_fake in d_fake:
533
- loss_g += torch.mean((1 - x_fake[-1]) ** 2)
534
-
535
- loss_feature = 0
536
-
537
- for i in range(len(d_fake)):
538
- for j in range(len(d_fake[i]) - 1):
539
- loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
540
- return loss_g, loss_feature
541
-
542
- def loss(self, fake, real):
543
- gen_loss, feature_distance = self.generator_loss(fake, real)
544
- dis_loss = self.discriminator_loss(fake, real)
545
-
546
- return dis_loss, gen_loss, feature_distance
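To show how the pieces above fit together, here is an illustrative sketch of `DACGANLoss` inside a GAN training step. The batch shapes and the feature-matching weight are assumptions, not values from this repo's training configs.

```python
# Illustrative GAN training step using DACGANLoss (shapes and weights are assumptions).
import torch

gan_loss = DACGANLoss(channels=2, sample_rate=44100)

reals = torch.randn(4, 2, 44100)   # ground-truth audio batch
fakes = torch.randn(4, 2, 44100)   # generator / decoder output

# Discriminator update: fakes are detached inside discriminator_loss
loss_d = gan_loss.discriminator_loss(fakes, reals)

# Generator update: adversarial term plus feature-matching term
loss_adv, loss_feat = gan_loss.generator_loss(fakes, reals)
loss_g = loss_adv + 10.0 * loss_feat   # 10.0 is an illustrative weight
```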
stable/build/lib/stable_audio_tools/models/dit.py DELETED
@@ -1,379 +0,0 @@
1
- import typing as tp
2
-
3
- import torch
4
-
5
- from einops import rearrange
6
- from torch import nn
7
- from torch.nn import functional as F
8
- from x_transformers import ContinuousTransformerWrapper, Encoder
9
-
10
- from .blocks import FourierFeatures
11
- from .transformer import ContinuousTransformer
12
-
13
- class DiffusionTransformer(nn.Module):
14
- def __init__(self,
15
- io_channels=32,
16
- patch_size=1,
17
- embed_dim=768,
18
- cond_token_dim=0,
19
- project_cond_tokens=True,
20
- global_cond_dim=0,
21
- project_global_cond=True,
22
- input_concat_dim=0,
23
- prepend_cond_dim=0,
24
- depth=12,
25
- num_heads=8,
26
- transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
27
- global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
28
- **kwargs):
29
-
30
- super().__init__()
31
-
32
- self.cond_token_dim = cond_token_dim
33
-
34
- # Timestep embeddings
35
- timestep_features_dim = 256
36
-
37
- self.timestep_features = FourierFeatures(1, timestep_features_dim)
38
-
39
- self.to_timestep_embed = nn.Sequential(
40
- nn.Linear(timestep_features_dim, embed_dim, bias=True),
41
- nn.SiLU(),
42
- nn.Linear(embed_dim, embed_dim, bias=True),
43
- )
44
-
45
- if cond_token_dim > 0:
46
- # Conditioning tokens
47
-
48
- cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
49
- self.to_cond_embed = nn.Sequential(
50
- nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
51
- nn.SiLU(),
52
- nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
53
- )
54
- else:
55
- cond_embed_dim = 0
56
-
57
- if global_cond_dim > 0:
58
- # Global conditioning
59
- global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
60
- self.to_global_embed = nn.Sequential(
61
- nn.Linear(global_cond_dim, global_embed_dim, bias=False),
62
- nn.SiLU(),
63
- nn.Linear(global_embed_dim, global_embed_dim, bias=False)
64
- )
65
-
66
- if prepend_cond_dim > 0:
67
- # Prepend conditioning
68
- self.to_prepend_embed = nn.Sequential(
69
- nn.Linear(prepend_cond_dim, embed_dim, bias=False),
70
- nn.SiLU(),
71
- nn.Linear(embed_dim, embed_dim, bias=False)
72
- )
73
-
74
- self.input_concat_dim = input_concat_dim
75
-
76
- dim_in = io_channels + self.input_concat_dim
77
-
78
- self.patch_size = patch_size
79
-
80
- # Transformer
81
-
82
- self.transformer_type = transformer_type
83
-
84
- self.global_cond_type = global_cond_type
85
-
86
- if self.transformer_type == "x-transformers":
87
- self.transformer = ContinuousTransformerWrapper(
88
- dim_in=dim_in * patch_size,
89
- dim_out=io_channels * patch_size,
90
- max_seq_len=0, #Not relevant without absolute positional embeds
91
- attn_layers = Encoder(
92
- dim=embed_dim,
93
- depth=depth,
94
- heads=num_heads,
95
- attn_flash = True,
96
- cross_attend = cond_token_dim > 0,
97
- dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
98
- zero_init_branch_output=True,
99
- use_abs_pos_emb = False,
100
- rotary_pos_emb=True,
101
- ff_swish = True,
102
- ff_glu = True,
103
- **kwargs
104
- )
105
- )
106
-
107
- elif self.transformer_type == "continuous_transformer":
108
-
109
- global_dim = None
110
-
111
- if self.global_cond_type == "adaLN":
112
- # The global conditioning is projected to the embed_dim already at this point
113
- global_dim = embed_dim
114
-
115
- self.transformer = ContinuousTransformer(
116
- dim=embed_dim,
117
- depth=depth,
118
- dim_heads=embed_dim // num_heads,
119
- dim_in=dim_in * patch_size,
120
- dim_out=io_channels * patch_size,
121
- cross_attend = cond_token_dim > 0,
122
- cond_token_dim = cond_embed_dim,
123
- global_cond_dim=global_dim,
124
- **kwargs
125
- )
126
-
127
- else:
128
- raise ValueError(f"Unknown transformer type: {self.transformer_type}")
129
-
130
- self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
131
- nn.init.zeros_(self.preprocess_conv.weight)
132
- self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
133
- nn.init.zeros_(self.postprocess_conv.weight)
134
-
135
- def _forward(
136
- self,
137
- x,
138
- t,
139
- mask=None,
140
- cross_attn_cond=None,
141
- cross_attn_cond_mask=None,
142
- input_concat_cond=None,
143
- global_embed=None,
144
- prepend_cond=None,
145
- prepend_cond_mask=None,
146
- return_info=False,
147
- **kwargs):
148
-
149
- if cross_attn_cond is not None:
150
- cross_attn_cond = self.to_cond_embed(cross_attn_cond)
151
-
152
- if global_embed is not None:
153
- # Project the global conditioning to the embedding dimension
154
- global_embed = self.to_global_embed(global_embed)
155
-
156
- prepend_inputs = None
157
- prepend_mask = None
158
- prepend_length = 0
159
- if prepend_cond is not None:
160
- # Project the prepend conditioning to the embedding dimension
161
- prepend_cond = self.to_prepend_embed(prepend_cond)
162
-
163
- prepend_inputs = prepend_cond
164
- if prepend_cond_mask is not None:
165
- prepend_mask = prepend_cond_mask
166
-
167
- if input_concat_cond is not None:
168
-
169
- # Interpolate input_concat_cond to the same length as x
170
- if input_concat_cond.shape[2] != x.shape[2]:
171
- input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
172
-
173
- x = torch.cat([x, input_concat_cond], dim=1)
174
-
175
- # Get the batch of timestep embeddings
176
- timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
177
-
178
- # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
179
- if global_embed is not None:
180
- global_embed = global_embed + timestep_embed
181
- else:
182
- global_embed = timestep_embed
183
-
184
- # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
185
- if self.global_cond_type == "prepend":
186
- if prepend_inputs is None:
187
- # Prepend inputs are just the global embed, and the mask is all ones
188
- prepend_inputs = global_embed.unsqueeze(1)
189
- prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
190
- else:
191
- # Prepend inputs are the prepend conditioning + the global embed
192
- prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
193
- prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
194
-
195
- prepend_length = prepend_inputs.shape[1]
196
-
197
- x = self.preprocess_conv(x) + x
198
-
199
- x = rearrange(x, "b c t -> b t c")
200
-
201
- extra_args = {}
202
-
203
- if self.global_cond_type == "adaLN":
204
- extra_args["global_cond"] = global_embed
205
-
206
- if self.patch_size > 1:
207
- x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
208
-
209
- if self.transformer_type == "x-transformers":
210
- output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
211
- elif self.transformer_type == "continuous_transformer":
212
- output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
213
-
214
- if return_info:
215
- output, info = output
216
- elif self.transformer_type == "mm_transformer":
217
- output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
218
-
219
- output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
220
-
221
- if self.patch_size > 1:
222
- output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
223
-
224
- output = self.postprocess_conv(output) + output
225
-
226
- if return_info:
227
- return output, info
228
-
229
- return output
230
-
231
- def forward(
232
- self,
233
- x,
234
- t,
235
- cross_attn_cond=None,
236
- cross_attn_cond_mask=None,
237
- negative_cross_attn_cond=None,
238
- negative_cross_attn_mask=None,
239
- input_concat_cond=None,
240
- global_embed=None,
241
- negative_global_embed=None,
242
- prepend_cond=None,
243
- prepend_cond_mask=None,
244
- cfg_scale=1.0,
245
- cfg_dropout_prob=0.0,
246
- causal=False,
247
- scale_phi=0.0,
248
- mask=None,
249
- return_info=False,
250
- **kwargs):
251
-
252
- assert causal == False, "Causal mode is not supported for DiffusionTransformer"
253
-
254
- if cross_attn_cond_mask is not None:
255
- cross_attn_cond_mask = cross_attn_cond_mask.bool()
256
-
257
- cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
258
-
259
- if prepend_cond_mask is not None:
260
- prepend_cond_mask = prepend_cond_mask.bool()
261
-
262
- # CFG dropout
263
- if cfg_dropout_prob > 0.0:
264
- if cross_attn_cond is not None:
265
- null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
266
- dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
267
- cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
268
-
269
- if prepend_cond is not None:
270
- null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
271
- dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
272
- prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
273
-
274
-
275
- if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
276
- # Classifier-free guidance
277
- # Concatenate conditioned and unconditioned inputs on the batch dimension
278
- batch_inputs = torch.cat([x, x], dim=0)
279
- batch_timestep = torch.cat([t, t], dim=0)
280
-
281
- if global_embed is not None:
282
- batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
283
- else:
284
- batch_global_cond = None
285
-
286
- if input_concat_cond is not None:
287
- batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
288
- else:
289
- batch_input_concat_cond = None
290
-
291
- batch_cond = None
292
- batch_cond_masks = None
293
-
294
- # Handle CFG for cross-attention conditioning
295
- if cross_attn_cond is not None:
296
-
297
- null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
298
-
299
- # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
300
- if negative_cross_attn_cond is not None:
301
-
302
- # If there's a negative cross-attention mask, set the masked tokens to the null embed
303
- if negative_cross_attn_mask is not None:
304
- negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
305
-
306
- negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
307
-
308
- batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
309
-
310
- else:
311
- batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
312
-
313
- if cross_attn_cond_mask is not None:
314
- batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
315
-
316
- batch_prepend_cond = None
317
- batch_prepend_cond_mask = None
318
-
319
- if prepend_cond is not None:
320
-
321
- null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
322
-
323
- batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
324
-
325
- if prepend_cond_mask is not None:
326
- batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
327
-
328
-
329
- if mask is not None:
330
- batch_masks = torch.cat([mask, mask], dim=0)
331
- else:
332
- batch_masks = None
333
-
334
- batch_output = self._forward(
335
- batch_inputs,
336
- batch_timestep,
337
- cross_attn_cond=batch_cond,
338
- cross_attn_cond_mask=batch_cond_masks,
339
- mask = batch_masks,
340
- input_concat_cond=batch_input_concat_cond,
341
- global_embed = batch_global_cond,
342
- prepend_cond = batch_prepend_cond,
343
- prepend_cond_mask = batch_prepend_cond_mask,
344
- return_info = return_info,
345
- **kwargs)
346
-
347
- if return_info:
348
- batch_output, info = batch_output
349
-
350
- cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
351
- cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
352
-
353
- # CFG Rescale
354
- if scale_phi != 0.0:
355
- cond_out_std = cond_output.std(dim=1, keepdim=True)
356
- out_cfg_std = cfg_output.std(dim=1, keepdim=True)
357
- output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
358
- else:
359
- output = cfg_output
360
-
361
- if return_info:
362
- return output, info
363
-
364
- return output
365
-
366
- else:
367
- return self._forward(
368
- x,
369
- t,
370
- cross_attn_cond=cross_attn_cond,
371
- cross_attn_cond_mask=cross_attn_cond_mask,
372
- input_concat_cond=input_concat_cond,
373
- global_embed=global_embed,
374
- prepend_cond=prepend_cond,
375
- prepend_cond_mask=prepend_cond_mask,
376
- mask=mask,
377
- return_info=return_info,
378
- **kwargs
379
- )
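The classifier-free guidance path in `forward` above reduces to a small amount of tensor arithmetic. The block below restates it on dummy tensors so the CFG and rescale steps can be read in isolation; the shapes and guidance values are illustrative.

```python
# Stand-alone restatement of the CFG + rescale arithmetic from DiffusionTransformer.forward.
import torch

cond_output   = torch.randn(2, 64, 256)   # conditioned branch output (b, c, t)
uncond_output = torch.randn(2, 64, 256)   # unconditioned branch output
cfg_scale, scale_phi = 6.0, 0.75          # illustrative guidance settings

cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

# CFG rescale: pull the guided output back toward the conditioned branch's per-channel std
cond_std = cond_output.std(dim=1, keepdim=True)
cfg_std  = cfg_output.std(dim=1, keepdim=True)
output = scale_phi * (cfg_output * (cond_std / cfg_std)) + (1 - scale_phi) * cfg_output
```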
stable/build/lib/stable_audio_tools/models/factory.py DELETED
@@ -1,153 +0,0 @@
1
- import json
2
-
3
- def create_model_from_config(model_config):
4
- model_type = model_config.get('model_type', None)
5
-
6
- assert model_type is not None, 'model_type must be specified in model config'
7
-
8
- if model_type == 'autoencoder':
9
- from .autoencoders import create_autoencoder_from_config
10
- return create_autoencoder_from_config(model_config)
11
- elif model_type == 'diffusion_uncond':
12
- from .diffusion import create_diffusion_uncond_from_config
13
- return create_diffusion_uncond_from_config(model_config)
14
- elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
15
- from .diffusion import create_diffusion_cond_from_config
16
- return create_diffusion_cond_from_config(model_config)
17
- elif model_type == 'diffusion_autoencoder':
18
- from .autoencoders import create_diffAE_from_config
19
- return create_diffAE_from_config(model_config)
20
- elif model_type == 'lm':
21
- from .lm import create_audio_lm_from_config
22
- return create_audio_lm_from_config(model_config)
23
- else:
24
- raise NotImplementedError(f'Unknown model type: {model_type}')
25
-
26
- def create_model_from_config_path(model_config_path):
27
- with open(model_config_path) as f:
28
- model_config = json.load(f)
29
-
30
- return create_model_from_config(model_config)
31
-
32
- def create_pretransform_from_config(pretransform_config, sample_rate):
33
- pretransform_type = pretransform_config.get('type', None)
34
-
35
- assert pretransform_type is not None, 'type must be specified in pretransform config'
36
-
37
- if pretransform_type == 'autoencoder':
38
- from .autoencoders import create_autoencoder_from_config
39
- from .pretransforms import AutoencoderPretransform
40
-
41
- # Create fake top-level config to pass sample rate to autoencoder constructor
42
- # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
43
- autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
44
- autoencoder = create_autoencoder_from_config(autoencoder_config)
45
-
46
- scale = pretransform_config.get("scale", 1.0)
47
- model_half = pretransform_config.get("model_half", False)
48
- iterate_batch = pretransform_config.get("iterate_batch", False)
49
- chunked = pretransform_config.get("chunked", False)
50
-
51
- pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
52
- elif pretransform_type == 'wavelet':
53
- from .pretransforms import WaveletPretransform
54
-
55
- wavelet_config = pretransform_config["config"]
56
- channels = wavelet_config["channels"]
57
- levels = wavelet_config["levels"]
58
- wavelet = wavelet_config["wavelet"]
59
-
60
- pretransform = WaveletPretransform(channels, levels, wavelet)
61
- elif pretransform_type == 'pqmf':
62
- from .pretransforms import PQMFPretransform
63
- pqmf_config = pretransform_config["config"]
64
- pretransform = PQMFPretransform(**pqmf_config)
65
- elif pretransform_type == 'dac_pretrained':
66
- from .pretransforms import PretrainedDACPretransform
67
- pretrained_dac_config = pretransform_config["config"]
68
- pretransform = PretrainedDACPretransform(**pretrained_dac_config)
69
- elif pretransform_type == "audiocraft_pretrained":
70
- from .pretransforms import AudiocraftCompressionPretransform
71
-
72
- audiocraft_config = pretransform_config["config"]
73
- pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
74
- else:
75
- raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
76
-
77
- enable_grad = pretransform_config.get('enable_grad', False)
78
- pretransform.enable_grad = enable_grad
79
-
80
- pretransform.eval().requires_grad_(pretransform.enable_grad)
81
-
82
- return pretransform
83
-
84
- def create_bottleneck_from_config(bottleneck_config):
85
- bottleneck_type = bottleneck_config.get('type', None)
86
-
87
- assert bottleneck_type is not None, 'type must be specified in bottleneck config'
88
-
89
- if bottleneck_type == 'tanh':
90
- from .bottleneck import TanhBottleneck
91
- bottleneck = TanhBottleneck()
92
- elif bottleneck_type == 'vae':
93
- from .bottleneck import VAEBottleneck
94
- bottleneck = VAEBottleneck()
95
- elif bottleneck_type == 'rvq':
96
- from .bottleneck import RVQBottleneck
97
-
98
- quantizer_params = {
99
- "dim": 128,
100
- "codebook_size": 1024,
101
- "num_quantizers": 8,
102
- "decay": 0.99,
103
- "kmeans_init": True,
104
- "kmeans_iters": 50,
105
- "threshold_ema_dead_code": 2,
106
- }
107
-
108
- quantizer_params.update(bottleneck_config["config"])
109
-
110
- bottleneck = RVQBottleneck(**quantizer_params)
111
- elif bottleneck_type == "dac_rvq":
112
- from .bottleneck import DACRVQBottleneck
113
-
114
- bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
115
-
116
- elif bottleneck_type == 'rvq_vae':
117
- from .bottleneck import RVQVAEBottleneck
118
-
119
- quantizer_params = {
120
- "dim": 128,
121
- "codebook_size": 1024,
122
- "num_quantizers": 8,
123
- "decay": 0.99,
124
- "kmeans_init": True,
125
- "kmeans_iters": 50,
126
- "threshold_ema_dead_code": 2,
127
- }
128
-
129
- quantizer_params.update(bottleneck_config["config"])
130
-
131
- bottleneck = RVQVAEBottleneck(**quantizer_params)
132
-
133
- elif bottleneck_type == 'dac_rvq_vae':
134
- from .bottleneck import DACRVQVAEBottleneck
135
- bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
136
- elif bottleneck_type == 'l2_norm':
137
- from .bottleneck import L2Bottleneck
138
- bottleneck = L2Bottleneck()
139
- elif bottleneck_type == "wasserstein":
140
- from .bottleneck import WassersteinBottleneck
141
- bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
142
- elif bottleneck_type == "fsq":
143
- from .bottleneck import FSQBottleneck
144
- bottleneck = FSQBottleneck(**bottleneck_config["config"])
145
- else:
146
- raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
147
-
148
- requires_grad = bottleneck_config.get('requires_grad', True)
149
- if not requires_grad:
150
- for param in bottleneck.parameters():
151
- param.requires_grad = False
152
-
153
- return bottleneck
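End to end, the factory above is normally driven from a JSON file on disk. A minimal sketch, assuming the package is installed as `stable_audio_tools` and the path is a placeholder:

```python
# Minimal sketch: build a model from a JSON config file (path is hypothetical).
from stable_audio_tools.models.factory import create_model_from_config_path

model = create_model_from_config_path("configs/model_config.json")
model = model.eval().requires_grad_(False)   # inference-only usage
```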
stable/build/lib/stable_audio_tools/models/lm.py DELETED
@@ -1,541 +0,0 @@
1
- from dataclasses import dataclass
2
- import torch
3
- from tqdm.auto import trange
4
- import typing as tp
5
- from einops import rearrange
6
- from torch import nn
7
-
8
- from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
9
- from .factory import create_pretransform_from_config
10
- from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
11
- from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
12
- from .utils import multinomial, sample_top_k, sample_top_p
13
-
14
- from .codebook_patterns import (
15
- CodebooksPatternProvider,
16
- DelayedPatternProvider,
17
- MusicLMPattern,
18
- ParallelPatternProvider,
19
- UnrolledPatternProvider
20
- )
21
-
22
- # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
23
- # License can be found in LICENSES/LICENSE_META.txt
24
-
25
- @dataclass
26
- class LMOutput:
27
- # The logits are already re-aligned with the input codes
28
- # hence no extra shift is required, e.g. when computing CE
29
- logits: torch.Tensor # [B, K, T, card]
30
- mask: torch.Tensor # [B, K, T]
31
-
32
- # Wrapper for a multi-codebook language model
33
- # Handles patterns and quantizer heads
34
- class AudioLanguageModel(nn.Module):
35
- def __init__(
36
- self,
37
- pattern_provider: CodebooksPatternProvider,
38
- backbone: AudioLMBackbone,
39
- num_quantizers: int,
40
- codebook_size: int
41
- ):
42
- super().__init__()
43
-
44
- self.pattern_provider = pattern_provider
45
- self.backbone = backbone
46
- self.num_quantizers = num_quantizers
47
- self.codebook_size = codebook_size
48
-
49
- self.masked_token_id = codebook_size
50
-
51
- # Per-quantizer embedders
52
- # Add one for the mask embed
53
- self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])
54
-
55
- # Per-quantizer output heads
56
- self.quantizer_heads = nn.ModuleList([
57
- nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
58
- ])
59
-
60
- def forward(self,
61
- sequence: torch.Tensor, #[batch, seq_len,
62
- prepend_cond=None, #[batch, seq, channels]
63
- prepend_cond_mask=None,
64
- cross_attn_cond=None, #[batch, seq, channels],
65
- **kwargs
66
- ):
67
-
68
- batch, num_quantizers, seq_len = sequence.shape
69
-
70
- assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"
71
-
72
- backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim]
73
-
74
- dtype = next(self.parameters()).dtype
75
-
76
- if cross_attn_cond is not None:
77
- cross_attn_cond = cross_attn_cond.to(dtype)
78
-
79
- if prepend_cond is not None:
80
- prepend_cond = prepend_cond.to(dtype)
81
-
82
- if prepend_cond_mask is not None:
83
- prepend_cond_mask = prepend_cond_mask.to(dtype)
84
-
85
- backbone_input = backbone_input.to(dtype)
86
-
87
- output = self.backbone(
88
- backbone_input,
89
- cross_attn_cond=cross_attn_cond,
90
- prepend_cond=prepend_cond,
91
- prepend_cond_mask=prepend_cond_mask,
92
- **kwargs
93
- ) # [batch, seq_len, embed_dim]
94
-
95
- # Run output through quantizer heads
96
- logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size]
97
-
98
- return logits
99
-
100
- def compute_logits(
101
- self,
102
- codes, #[batch, num_quantizers, seq_len]
103
- **kwargs):
104
- """
105
- Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning
106
- Handles translation between input sequence and pattern-shifted sequence
107
- Only used during training
108
- """
109
-
110
- batch, _, seq_len = codes.shape
111
-
112
- pattern = self.pattern_provider.get_pattern(seq_len)
113
-
114
- # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps
115
- shifted_codes, _, _ = pattern.build_pattern_sequence(
116
- codes,
117
- self.masked_token_id,
118
- keep_only_valid_steps=True
119
- )
120
-
121
- # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size]
122
- logits = self(shifted_codes, **kwargs)
123
-
124
- # Rearrange logits to prepare to revert pattern
125
- logits = rearrange(logits, "b n s c -> b c n s")
126
-
127
- # Revert sequence logits back to original sequence length, removing masked steps
128
- logits, _, logits_mask = pattern.revert_pattern_logits(
129
- logits, float('nan'), keep_only_valid_steps=True
130
- )
131
-
132
- logits = rearrange(logits, "b c n t -> b n t c")
133
-
134
- logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len]
135
-
136
- return LMOutput(logits=logits, mask=logits_mask)
137
-
138
- # Conditioning and generation wrapper for a multi-codebook language model
139
- # Handles conditioning, CFG, generation, and encoding/decoding
140
- class AudioLanguageModelWrapper(nn.Module):
141
- def __init__(
142
- self,
143
- pretransform: Pretransform,
144
- lm: AudioLanguageModel,
145
- sample_rate: int,
146
- min_input_length: int,
147
- conditioner: MultiConditioner = None,
148
- cross_attn_cond_ids: tp.List[str] = [],
149
- prepend_cond_ids: tp.List[str] = [],
150
- global_cond_ids: tp.List[str] = []
151
- ):
152
- super().__init__()
153
-
154
- assert pretransform.is_discrete, "Pretransform must be discrete"
155
- self.pretransform = pretransform
156
-
157
- self.pretransform.requires_grad_(False)
158
- self.pretransform.eval()
159
-
160
- if isinstance(self.pretransform, AutoencoderPretransform):
161
- self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
162
- self.codebook_size = self.pretransform.model.bottleneck.codebook_size
163
- elif isinstance(self.pretransform, PretrainedDACPretransform):
164
- self.num_quantizers = self.pretransform.model.num_quantizers
165
- self.codebook_size = self.pretransform.model.codebook_size
166
- elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
167
- self.num_quantizers = self.pretransform.num_quantizers
168
- self.codebook_size = self.pretransform.codebook_size
169
- else:
170
- raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
171
-
172
- self.conditioner = conditioner
173
-
174
- self.lm = lm
175
-
176
- self.sample_rate = sample_rate
177
- self.min_input_length = min_input_length
178
-
179
- self.cross_attn_cond_ids = cross_attn_cond_ids
180
- self.prepend_cond_ids = prepend_cond_ids
181
- self.global_cond_ids = global_cond_ids
182
-
183
- def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
184
- cross_attention_input = None
185
- prepend_cond = None
186
- prepend_cond_mask = None
187
- global_cond = None
188
-
189
- if len(self.cross_attn_cond_ids) > 0:
190
- # Concatenate all cross-attention inputs over the sequence dimension
191
- # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
192
- cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
193
-
194
- if len(self.prepend_cond_ids) > 0:
195
- # Concatenate all prepend conditioning inputs over the sequence dimension
196
- # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
197
- prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
198
- prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
199
-
200
- if len(self.global_cond_ids) > 0:
201
- # Concatenate all global conditioning inputs over the channel dimension
202
- # Assumes that the global conditioning inputs are of shape (batch, channels)
203
- global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
204
- if len(global_cond.shape) == 3:
205
- global_cond = global_cond.squeeze(1)
206
-
207
- if negative:
208
- return {
209
- "negative_cross_attn_cond": cross_attention_input,
210
- "negative_prepend_cond": prepend_cond,
211
- "negative_prepend_cond_mask": prepend_cond_mask,
212
- "negative_global_cond": global_cond
213
- }
214
- else:
215
- return {
216
- "cross_attn_cond": cross_attention_input,
217
- "prepend_cond": prepend_cond,
218
- "prepend_cond_mask": prepend_cond_mask,
219
- "global_cond": global_cond
220
- }
221
-
222
- def compute_logits(
223
- self,
224
- codes,
225
- condition_tensors=None,
226
- cfg_dropout_prob=0.0,
227
- **kwargs
228
- ):
229
- """
230
- Compute logits for a batch of codes, translating conditioning inputs into model inputs
231
- Handles CFG dropout
232
- """
233
-
234
- if condition_tensors is None:
235
- condition_tensors = {}
236
-
237
- conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
238
-
239
- cross_attn_cond = conditioning_inputs["cross_attn_cond"]
240
- prepend_cond = conditioning_inputs["prepend_cond"]
241
- prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
242
- global_cond = conditioning_inputs["global_cond"]
243
-
244
- if cfg_dropout_prob > 0.0:
245
- if cross_attn_cond is not None:
246
- null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
247
- dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
248
- cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
249
-
250
- if prepend_cond is not None:
251
- null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
252
- dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
253
- prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
254
-
255
- if global_cond is not None:
256
- null_embed = torch.zeros_like(global_cond, device=global_cond.device)
257
- dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
258
- global_cond = torch.where(dropout_mask, null_embed, global_cond)
259
-
260
- return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
261
-
262
- def _sample_next_token(
263
- self,
264
- sequence, #[batch, num_quantizers, seq_len]
265
- conditioning_tensors=None,
266
- cross_attn_use_cfg=True,
267
- prepend_use_cfg=True,
268
- global_use_cfg=True,
269
- cfg_scale=1.0,
270
- top_k=250,
271
- top_p=0.0,
272
- temp=1.0,
273
- **kwargs
274
- ):
275
- """
276
- Sample the next token for a batch of codes, translating conditioning inputs into model inputs
277
- Handles CFG inference
278
- """
279
-
280
- if conditioning_tensors is None:
281
- conditioning_tensors = {}
282
-
283
- conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
284
-
285
- cross_attn_cond = conditioning_inputs["cross_attn_cond"]
286
- prepend_cond = conditioning_inputs["prepend_cond"]
287
- prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
288
- global_cond = conditioning_inputs["global_cond"]
289
-
290
- if cfg_scale != 1.0:
291
-
292
- # Batch size is doubled to account for negative samples
293
- sequence = torch.cat([sequence, sequence], dim=0)
294
-
295
- if cross_attn_cond is not None and cross_attn_use_cfg:
296
- null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
297
-
298
- cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
299
-
300
- if prepend_cond is not None and prepend_use_cfg:
301
- null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
302
-
303
- prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
304
-
305
- if prepend_cond_mask is not None:
306
- prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
307
-
308
- if global_cond is not None and global_use_cfg:
309
- null_embed = torch.zeros_like(global_cond, device=global_cond.device)
310
-
311
- global_cond = torch.cat([global_cond, null_embed], dim=0)
312
-
313
- logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
314
-
315
- if cfg_scale != 1.0:
316
- cond_logits, uncond_logits = logits.chunk(2, dim=0)
317
-
318
- logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
319
-
320
- logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
321
-
322
- # Grab the logits for the last step
323
- logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
324
-
325
- # Apply top-k or top-p sampling
326
-
327
- if temp > 0:
328
- probs = torch.softmax(logits / temp, dim=-1)
329
-
330
- if top_p > 0.0:
331
- next_token = sample_top_p(probs, p=top_p)
332
- elif top_k > 0:
333
- next_token = sample_top_k(probs, k=top_k)
334
- else:
335
- next_token = multinomial(probs, num_samples=1)
336
-
337
- else:
338
- next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
339
-
340
- return next_token
341
-
342
- @torch.no_grad()
343
- def generate(
344
- self,
345
- max_gen_len: int = 256,
346
- batch_size: tp.Optional[int] = None,
347
- init_data: tp.Optional[torch.Tensor] = None,
348
- conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
349
- conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
350
- callback: tp.Optional[tp.Callable[[int, int], None]] = None,
351
- use_cache: bool = True,
352
- cfg_scale: float = 1.0,
353
- **kwargs
354
- ):
355
- device = next(self.parameters()).device
356
-
357
- if conditioning_tensors is None and conditioning is not None:
358
- # Convert conditioning inputs to conditioning tensors
359
- conditioning_tensors = self.conditioner(conditioning, device)
360
-
361
- # Check that batch size is consistent across inputs
362
- possible_batch_sizes = []
363
-
364
- if batch_size is not None:
365
- possible_batch_sizes.append(batch_size)
366
- elif init_data is not None:
367
- possible_batch_sizes.append(init_data.shape[0])
368
- elif conditioning_tensors is not None:
369
- # Assume that the first conditioning tensor has the batch dimension
370
- possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
371
- else:
372
- possible_batch_sizes.append(1)
373
-
374
- assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"
375
-
376
- batch_size = possible_batch_sizes[0]
377
-
378
- if init_data is None:
379
- # Initialize with zeros
380
- assert batch_size > 0
381
- init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
382
-
383
- batch_size, num_quantizers, seq_len = init_data.shape
384
-
385
- start_offset = seq_len
386
- assert start_offset < max_gen_len, "init data longer than max gen length"
387
-
388
- pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
389
-
390
- unknown_token = -1
391
-
392
- # Initialize the generated codes with the init data, padded with unknown tokens
393
- gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
394
- gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
395
-
396
- gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
397
-
398
- start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
399
- assert start_offset_sequence is not None
400
-
401
- # Generation
402
- prev_offset = 0
403
- gen_sequence_len = gen_sequence.shape[-1]
404
-
405
- # Reset generation cache
406
- if use_cache and self.lm.backbone.use_generation_cache:
407
- self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
408
-
409
- for offset in trange(start_offset_sequence, gen_sequence_len):
410
-
411
- # Get the full sequence up to the current offset
412
- curr_sequence = gen_sequence[..., prev_offset:offset]
413
-
414
- next_token = self._sample_next_token(
415
- curr_sequence,
416
- conditioning_tensors=conditioning_tensors,
417
- use_cache=use_cache,
418
- cfg_scale=cfg_scale,
419
- **kwargs
420
- )
421
-
422
- valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
423
- next_token[~valid_mask] = self.lm.masked_token_id
424
-
425
- # Update the generated sequence with the next token
426
- gen_sequence[..., offset:offset+1] = torch.where(
427
- gen_sequence[..., offset:offset+1] == unknown_token,
428
- next_token,
429
- gen_sequence[..., offset:offset+1]
430
- )
431
-
432
- if use_cache and self.lm.backbone.use_generation_cache:
433
- # Only update the offset if caching is being used
434
- prev_offset = offset
435
-
436
- self.lm.backbone.update_generation_cache(offset)
437
-
438
- if callback is not None:
439
- # Callback to report progress
440
- # Pass in the offset relative to the start of the sequence, and the length of the current sequence
441
- callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
442
-
443
- assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
444
-
445
- out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
446
-
447
- # sanity checks over the returned codes and corresponding masks
448
- assert (out_codes[..., :max_gen_len] != unknown_token).all()
449
- assert (out_mask[..., :max_gen_len] == 1).all()
450
-
451
- #out_codes = out_codes[..., 0:max_gen_len]
452
-
453
- return out_codes
454
-
455
-
456
- def generate_audio(
457
- self,
458
- **kwargs
459
- ):
460
- """
461
- Generate audio from a batch of codes
462
- """
463
-
464
- codes = self.generate(**kwargs)
465
-
466
- audio = self.pretransform.decode_tokens(codes)
467
-
468
- return audio
469
-
470
-
471
- def create_audio_lm_from_config(config):
472
- model_config = config.get('model', None)
473
- assert model_config is not None, 'model config must be specified in config'
474
-
475
- sample_rate = config.get('sample_rate', None)
476
- assert sample_rate is not None, "Must specify sample_rate in config"
477
-
478
- lm_config = model_config.get('lm', None)
479
- assert lm_config is not None, 'lm config must be specified in model config'
480
-
481
- codebook_pattern = lm_config.get("codebook_pattern", "delay")
482
-
483
- pattern_providers = {
484
- 'parallel': ParallelPatternProvider,
485
- 'delay': DelayedPatternProvider,
486
- 'unroll': UnrolledPatternProvider,
487
- 'musiclm': MusicLMPattern,
488
- }
489
-
490
- pretransform_config = model_config.get("pretransform", None)
491
-
492
- pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
493
-
494
- assert pretransform.is_discrete, "Pretransform must be discrete"
495
-
496
- min_input_length = pretransform.downsampling_ratio
497
-
498
- pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)
499
-
500
- conditioning_config = model_config.get('conditioning', None)
501
-
502
- conditioner = None
503
- if conditioning_config is not None:
504
- conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
505
-
506
- cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
507
- prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
508
- global_cond_ids = lm_config.get('global_cond_ids', [])
509
-
510
- lm_type = lm_config.get("type", None)
511
- lm_model_config = lm_config.get("config", None)
512
-
513
- assert lm_type is not None, "Must specify lm type in lm config"
514
- assert lm_model_config is not None, "Must specify lm model config in lm config"
515
-
516
- if lm_type == "x-transformers":
517
- backbone = XTransformersAudioLMBackbone(**lm_model_config)
518
- elif lm_type == "continuous_transformer":
519
- backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
520
- else:
521
- raise NotImplementedError(f"Unrecognized lm type {lm_type}")
522
-
523
- lm = AudioLanguageModel(
524
- pattern_provider=pattern_provider,
525
- backbone=backbone,
526
- num_quantizers=pretransform.num_quantizers,
527
- codebook_size=pretransform.codebook_size
528
- )
529
-
530
- model = AudioLanguageModelWrapper(
531
- pretransform=pretransform,
532
- lm=lm,
533
- conditioner=conditioner,
534
- sample_rate=sample_rate,
535
- min_input_length=min_input_length,
536
- cross_attn_cond_ids=cross_attn_cond_ids,
537
- prepend_cond_ids=prepend_cond_ids,
538
- global_cond_ids=global_cond_ids
539
- )
540
-
541
- return model
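For orientation, a hedged usage sketch tying the factory and wrapper above together. The config keys mirror what `create_audio_lm_from_config` reads; the pretransform and conditioning sub-configs, the `"prompt"` conditioning id, and the sampling settings are illustrative placeholders, and the exact metadata format expected by the conditioner is an assumption rather than something shown in this file.

```python
import torch

# Key names follow create_audio_lm_from_config above; values are placeholders, not a tested preset.
example_config = {
    "sample_rate": 32000,
    "model": {
        "pretransform": {"type": "...", "config": {}},   # must resolve to a *discrete* pretransform
        "conditioning": None,                            # or a config parsed into a MultiConditioner
        "lm": {
            "type": "continuous_transformer",            # or "x-transformers"
            "codebook_pattern": "delay",                 # parallel / delay / unroll / musiclm
            "cross_attention_cond_ids": ["prompt"],
            "prepend_cond_ids": [],
            "global_cond_ids": [],
            "config": {"embed_dim": 1024},               # forwarded to the backbone constructor
        },
    },
}

# The placeholders above must be filled in with a real pretransform/conditioning config first.
model = create_audio_lm_from_config(example_config).eval()

with torch.no_grad():
    audio = model.generate_audio(
        batch_size=1,
        max_gen_len=512,                                      # number of token steps to sample
        conditioning=[{"prompt": "warm analog synth pad"}],   # structure depends on the conditioner config
        cfg_scale=3.0,
        top_k=250,
        temp=1.0,
    )
# audio: [batch, io_channels, samples], decoded through the discrete pretransform
```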
stable/build/lib/stable_audio_tools/models/lm_backbone.py DELETED
@@ -1,159 +0,0 @@
1
- from torch import nn
2
- from x_transformers import ContinuousTransformerWrapper, Decoder
3
-
4
- from .transformer import ContinuousTransformer
5
-
6
- # Interface for backbone of a language model
7
- # Handles conditioning and cross-attention
8
- # Does not have to deal with patterns or quantizer heads
9
- class AudioLMBackbone(nn.Module):
10
- def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
11
- super().__init__()
12
-
13
- self.embed_dim = embed_dim
14
- self.use_generation_cache = use_generation_cache
15
-
16
- def forward(
17
- self,
18
- x,
19
- cross_attn_cond=None,
20
- prepend_cond=None,
21
- prepend_cond_mask=None,
22
- global_cond=None,
23
- use_cache=False,
24
- **kwargs
25
- ):
26
- raise NotImplementedError
27
-
28
- def reset_generation_cache(
29
- self,
30
- max_seq_len,
31
- batch_size,
32
- dtype=None
33
- ):
34
- pass
35
-
36
- def update_generation_cache(
37
- self,
38
- seqlen_offset
39
- ):
40
- pass
41
-
42
- class XTransformersAudioLMBackbone(AudioLMBackbone):
43
- def __init__(self,
44
- embed_dim: int,
45
- cross_attn_cond_dim: int = 0,
46
- prepend_cond_dim: int = 0,
47
- **kwargs):
48
- super().__init__(embed_dim=embed_dim)
49
-
50
- # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
51
- self.model = ContinuousTransformerWrapper(
52
- dim_in=embed_dim,
53
- dim_out=embed_dim,
54
- max_seq_len=0, #Not relevant without absolute positional embeds,
55
- attn_layers=Decoder(
56
- dim=embed_dim,
57
- attn_flash = True,
58
- cross_attend = cross_attn_cond_dim > 0,
59
- zero_init_branch_output=True,
60
- use_abs_pos_emb = False,
61
- rotary_pos_emb=True,
62
- ff_swish = True,
63
- ff_glu = True,
64
- **kwargs
65
- )
66
- )
67
-
68
- if prepend_cond_dim > 0:
69
- # Prepend conditioning
70
- self.to_prepend_embed = nn.Sequential(
71
- nn.Linear(prepend_cond_dim, embed_dim, bias=False),
72
- nn.SiLU(),
73
- nn.Linear(embed_dim, embed_dim, bias=False)
74
- )
75
-
76
- if cross_attn_cond_dim > 0:
77
- # Cross-attention conditioning
78
- self.to_cross_attn_embed = nn.Sequential(
79
- nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
80
- nn.SiLU(),
81
- nn.Linear(embed_dim, embed_dim, bias=False)
82
- )
83
-
84
- def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
85
-
86
- prepend_length = 0
87
- if prepend_cond is not None:
88
- # Project the prepend conditioning to the embedding dimension
89
- prepend_cond = self.to_prepend_embed(prepend_cond)
90
- prepend_length = prepend_cond.shape[1]
91
-
92
- if prepend_cond_mask is not None:
93
- # Cast mask to bool
94
- prepend_cond_mask = prepend_cond_mask.bool()
95
-
96
- if cross_attn_cond is not None:
97
- # Project the cross-attention conditioning to the embedding dimension
98
- cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)
99
-
100
- return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
101
-
102
- class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
103
- def __init__(self,
104
- embed_dim: int,
105
- cross_attn_cond_dim: int = 0,
106
- prepend_cond_dim: int = 0,
107
- project_cross_attn_cond: bool = False,
108
- **kwargs):
109
- super().__init__(embed_dim=embed_dim)
110
-
111
- # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
112
- self.model = ContinuousTransformer(
113
- dim=embed_dim,
114
- dim_in=embed_dim,
115
- dim_out=embed_dim,
116
- cross_attend = cross_attn_cond_dim > 0,
117
- cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
118
- causal=True,
119
- **kwargs
120
- )
121
-
122
- if prepend_cond_dim > 0:
123
- # Prepend conditioning
124
- self.to_prepend_embed = nn.Sequential(
125
- nn.Linear(prepend_cond_dim, embed_dim, bias=False),
126
- nn.SiLU(),
127
- nn.Linear(embed_dim, embed_dim, bias=False)
128
- )
129
-
130
- if cross_attn_cond_dim > 0 and project_cross_attn_cond:
131
- # Cross-attention conditioning
132
- self.to_cross_attn_embed = nn.Sequential(
133
- nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
134
- nn.SiLU(),
135
- nn.Linear(embed_dim, embed_dim, bias=False)
136
- )
137
- else:
138
- self.to_cross_attn_embed = nn.Identity()
139
-
140
- def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
141
-
142
- prepend_length = 0
143
- if prepend_cond is not None:
144
- # Project the prepend conditioning to the embedding dimension
145
- prepend_cond = self.to_prepend_embed(prepend_cond)
146
- prepend_length = prepend_cond.shape[1]
147
-
148
- if prepend_cond_mask is not None:
149
- # Cast mask to bool
150
- prepend_cond_mask = prepend_cond_mask.bool()
151
-
152
- if cross_attn_cond is not None:
153
- # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed
154
- cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype)
155
-
156
- # Project the cross-attention conditioning to the embedding dimension
157
- cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)
158
-
159
- return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
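A minimal sketch of driving the x-transformers backbone above with dummy tensors. The sizes are arbitrary, and `depth`/`heads` are assumptions that are simply forwarded to the x-transformers `Decoder` through `**kwargs`.

```python
import torch

# Illustrative sizes only; cross_attn_cond_dim could be e.g. a text encoder's hidden size.
backbone = XTransformersAudioLMBackbone(
    embed_dim=768,
    cross_attn_cond_dim=512,
    prepend_cond_dim=0,
    depth=8,      # forwarded to the x-transformers Decoder
    heads=8,      # forwarded to the x-transformers Decoder
)

x = torch.randn(2, 100, 768)          # summed codebook embeddings [batch, seq, embed_dim]
text_cond = torch.randn(2, 77, 512)   # cross-attention conditioning [batch, cond_seq, cond_dim]

out = backbone(x, cross_attn_cond=text_cond)
print(out.shape)                      # torch.Size([2, 100, 768])
```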
stable/build/lib/stable_audio_tools/models/local_attention.py DELETED
@@ -1,278 +0,0 @@
1
- import torch
2
-
3
- from einops import rearrange
4
- from torch import nn
5
-
6
- from .blocks import AdaRMSNorm
7
- from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
8
-
9
- def checkpoint(function, *args, **kwargs):
10
- kwargs.setdefault("use_reentrant", False)
11
- return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
12
-
13
- # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
14
- class ContinuousLocalTransformer(nn.Module):
15
- def __init__(
16
- self,
17
- *,
18
- dim,
19
- depth,
20
- dim_in = None,
21
- dim_out = None,
22
- causal = False,
23
- local_attn_window_size = 64,
24
- heads = 8,
25
- ff_mult = 2,
26
- cond_dim = 0,
27
- cross_attn_cond_dim = 0,
28
- **kwargs
29
- ):
30
- super().__init__()
31
-
32
- dim_head = dim//heads
33
-
34
- self.layers = nn.ModuleList([])
35
-
36
- self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
37
-
38
- self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
39
-
40
- self.local_attn_window_size = local_attn_window_size
41
-
42
- self.cond_dim = cond_dim
43
-
44
- self.cross_attn_cond_dim = cross_attn_cond_dim
45
-
46
- self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
47
-
48
- for _ in range(depth):
49
-
50
- self.layers.append(nn.ModuleList([
51
- AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
52
- Attention(
53
- dim=dim,
54
- dim_heads=dim_head,
55
- causal=causal,
56
- zero_init_output=True,
57
- natten_kernel_size=local_attn_window_size,
58
- ),
59
- Attention(
60
- dim=dim,
61
- dim_heads=dim_head,
62
- dim_context = cross_attn_cond_dim,
63
- zero_init_output=True
64
- ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
65
- AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
66
- FeedForward(dim = dim, mult = ff_mult, no_bias=True)
67
- ]))
68
-
69
- def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
70
-
71
- x = checkpoint(self.project_in, x)
72
-
73
- if prepend_cond is not None:
74
- x = torch.cat([prepend_cond, x], dim=1)
75
-
76
- pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
77
-
78
- for attn_norm, attn, xattn, ff_norm, ff in self.layers:
79
-
80
- residual = x
81
- if cond is not None:
82
- x = checkpoint(attn_norm, x, cond)
83
- else:
84
- x = checkpoint(attn_norm, x)
85
-
86
- x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
87
-
88
- if cross_attn_cond is not None:
89
- x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
90
-
91
- residual = x
92
-
93
- if cond is not None:
94
- x = checkpoint(ff_norm, x, cond)
95
- else:
96
- x = checkpoint(ff_norm, x)
97
-
98
- x = checkpoint(ff, x) + residual
99
-
100
- return checkpoint(self.project_out, x)
101
-
102
- class TransformerDownsampleBlock1D(nn.Module):
103
- def __init__(
104
- self,
105
- in_channels,
106
- embed_dim = 768,
107
- depth = 3,
108
- heads = 12,
109
- downsample_ratio = 2,
110
- local_attn_window_size = 64,
111
- **kwargs
112
- ):
113
- super().__init__()
114
-
115
- self.downsample_ratio = downsample_ratio
116
-
117
- self.transformer = ContinuousLocalTransformer(
118
- dim=embed_dim,
119
- depth=depth,
120
- heads=heads,
121
- local_attn_window_size=local_attn_window_size,
122
- **kwargs
123
- )
124
-
125
- self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
126
-
127
- self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
128
-
129
-
130
- def forward(self, x):
131
-
132
- x = checkpoint(self.project_in, x)
133
-
134
- # Compute
135
- x = self.transformer(x)
136
-
137
- # Trade sequence length for channels
138
- x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
139
-
140
- # Project back to embed dim
141
- x = checkpoint(self.project_down, x)
142
-
143
- return x
144
-
145
- class TransformerUpsampleBlock1D(nn.Module):
146
- def __init__(
147
- self,
148
- in_channels,
149
- embed_dim,
150
- depth = 3,
151
- heads = 12,
152
- upsample_ratio = 2,
153
- local_attn_window_size = 64,
154
- **kwargs
155
- ):
156
- super().__init__()
157
-
158
- self.upsample_ratio = upsample_ratio
159
-
160
- self.transformer = ContinuousLocalTransformer(
161
- dim=embed_dim,
162
- depth=depth,
163
- heads=heads,
164
- local_attn_window_size = local_attn_window_size,
165
- **kwargs
166
- )
167
-
168
- self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
169
-
170
- self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
171
-
172
- def forward(self, x):
173
-
174
- # Project to embed dim
175
- x = checkpoint(self.project_in, x)
176
-
177
- # Project to increase channel dim
178
- x = checkpoint(self.project_up, x)
179
-
180
- # Trade channels for sequence length
181
- x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
182
-
183
- # Compute
184
- x = self.transformer(x)
185
-
186
- return x
187
-
188
-
189
- class TransformerEncoder1D(nn.Module):
190
- def __init__(
191
- self,
192
- in_channels,
193
- out_channels,
194
- embed_dims = [96, 192, 384, 768],
195
- heads = [12, 12, 12, 12],
196
- depths = [3, 3, 3, 3],
197
- ratios = [2, 2, 2, 2],
198
- local_attn_window_size = 64,
199
- **kwargs
200
- ):
201
- super().__init__()
202
-
203
- layers = []
204
-
205
- for layer in range(len(depths)):
206
- prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
207
-
208
- layers.append(
209
- TransformerDownsampleBlock1D(
210
- in_channels = prev_dim,
211
- embed_dim = embed_dims[layer],
212
- heads = heads[layer],
213
- depth = depths[layer],
214
- downsample_ratio = ratios[layer],
215
- local_attn_window_size = local_attn_window_size,
216
- **kwargs
217
- )
218
- )
219
-
220
- self.layers = nn.Sequential(*layers)
221
-
222
- self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
223
- self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
224
-
225
- def forward(self, x):
226
- x = rearrange(x, "b c n -> b n c")
227
- x = checkpoint(self.project_in, x)
228
- x = self.layers(x)
229
- x = checkpoint(self.project_out, x)
230
- x = rearrange(x, "b n c -> b c n")
231
-
232
- return x
233
-
234
-
235
- class TransformerDecoder1D(nn.Module):
236
- def __init__(
237
- self,
238
- in_channels,
239
- out_channels,
240
- embed_dims = [768, 384, 192, 96],
241
- heads = [12, 12, 12, 12],
242
- depths = [3, 3, 3, 3],
243
- ratios = [2, 2, 2, 2],
244
- local_attn_window_size = 64,
245
- **kwargs
246
- ):
247
-
248
- super().__init__()
249
-
250
- layers = []
251
-
252
- for layer in range(len(depths)):
253
- prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
254
-
255
- layers.append(
256
- TransformerUpsampleBlock1D(
257
- in_channels = prev_dim,
258
- embed_dim = embed_dims[layer],
259
- heads = heads[layer],
260
- depth = depths[layer],
261
- upsample_ratio = ratios[layer],
262
- local_attn_window_size = local_attn_window_size,
263
- **kwargs
264
- )
265
- )
266
-
267
- self.layers = nn.Sequential(*layers)
268
-
269
- self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
270
- self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
271
-
272
- def forward(self, x):
273
- x = rearrange(x, "b c n -> b n c")
274
- x = checkpoint(self.project_in, x)
275
- x = self.layers(x)
276
- x = checkpoint(self.project_out, x)
277
- x = rearrange(x, "b n c -> b c n")
278
- return x
stable/build/lib/stable_audio_tools/models/pqmf.py DELETED
@@ -1,393 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- import torch.nn as nn
5
- from einops import rearrange
6
- from scipy.optimize import fmin
7
- from scipy.signal import firwin, kaiserord
8
-
9
- class PQMF(nn.Module):
10
- """
11
- Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
12
- Uses a polyphase representation, which is computationally more efficient for real-time processing.
13
-
14
- Parameters:
15
- - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
16
- - num_bands (int): Number of desired frequency bands. It must be a power of 2.
17
- """
18
-
19
- def __init__(self, attenuation, num_bands):
20
- super(PQMF, self).__init__()
21
-
22
- # Ensure num_bands is a power of 2
23
- is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
24
- assert is_power_of_2, "'num_bands' must be a power of 2."
25
-
26
- # Create the prototype filter
27
- prototype_filter = design_prototype_filter(attenuation, num_bands)
28
- filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
29
- padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
30
-
31
- # Register filters and settings
32
- self.register_buffer("filter_bank", padded_filter_bank)
33
- self.register_buffer("prototype", prototype_filter)
34
- self.num_bands = num_bands
35
-
36
- def forward(self, signal):
37
- """Decompose the signal into multiple frequency bands."""
38
- # If signal is not a pytorch tensor of Batch x Channels x Length, convert it
39
- signal = prepare_signal_dimensions(signal)
40
- # The signal length must be a multiple of num_bands. Pad it with zeros.
41
- signal = pad_signal(signal, self.num_bands)
42
- # run it
43
- signal = polyphase_analysis(signal, self.filter_bank)
44
- return apply_alias_cancellation(signal)
45
-
46
- def inverse(self, bands):
47
- """Reconstruct the original signal from the frequency bands."""
48
- bands = apply_alias_cancellation(bands)
49
- return polyphase_synthesis(bands, self.filter_bank)
50
-
51
-
52
- def prepare_signal_dimensions(signal):
53
- """
54
- Rearrange signal into Batch x Channels x Length.
55
-
56
- Parameters
57
- ----------
58
- signal : torch.Tensor or numpy.ndarray
59
- The input signal.
60
-
61
- Returns
62
- -------
63
- torch.Tensor
64
- Preprocessed signal tensor.
65
- """
66
- # Convert numpy to torch tensor
67
- if isinstance(signal, np.ndarray):
68
- signal = torch.from_numpy(signal)
69
-
70
- # Ensure tensor
71
- if not isinstance(signal, torch.Tensor):
72
- raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
73
-
74
- # Modify dimension of signal to Batch x Channels x Length
75
- if signal.dim() == 1:
76
- # This is just a mono signal. Unsqueeze to 1 x 1 x Length
77
- signal = signal.unsqueeze(0).unsqueeze(0)
78
- elif signal.dim() == 2:
79
- # This is a multi-channel signal (e.g. stereo)
80
- # Rearrange so that larger dimension (Length) is last
81
- if signal.shape[0] > signal.shape[1]:
82
- signal = signal.T
83
- # Unsqueeze to 1 x Channels x Length
84
- signal = signal.unsqueeze(0)
85
- return signal
86
-
87
- def pad_signal(signal, num_bands):
88
- """
89
- Pads the signal to make its length divisible by the given number of bands.
90
-
91
- Parameters
92
- ----------
93
- signal : torch.Tensor
94
- The input signal tensor, where the last dimension represents the signal length.
95
-
96
- num_bands : int
97
- The number of bands by which the signal length should be divisible.
98
-
99
- Returns
100
- -------
101
- torch.Tensor
102
- The padded signal tensor. If the original signal length was already divisible
103
- by num_bands, returns the original signal unchanged.
104
- """
105
- remainder = signal.shape[-1] % num_bands
106
- if remainder > 0:
107
- padding_size = num_bands - remainder
108
- signal = nn.functional.pad(signal, (0, padding_size))
109
- return signal
110
-
111
- def generate_modulated_filter_bank(prototype_filter, num_bands):
112
- """
113
- Generate a QMF bank of cosine modulated filters based on a given prototype filter.
114
-
115
- Parameters
116
- ----------
117
- prototype_filter : torch.Tensor
118
- The prototype filter used as the basis for modulation.
119
- num_bands : int
120
- The number of desired subbands or filters.
121
-
122
- Returns
123
- -------
124
- torch.Tensor
125
- A bank of cosine modulated filters.
126
- """
127
-
128
- # Initialize indices for modulation.
129
- subband_indices = torch.arange(num_bands).reshape(-1, 1)
130
-
131
- # Calculate the length of the prototype filter.
132
- filter_length = prototype_filter.shape[-1]
133
-
134
- # Generate symmetric time indices centered around zero.
135
- time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
136
-
137
- # Calculate phase offsets to ensure orthogonality between subbands.
138
- phase_offsets = (-1)**subband_indices * np.pi / 4
139
-
140
- # Compute the cosine modulation function.
141
- modulation = torch.cos(
142
- (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
143
- )
144
-
145
- # Apply modulation to the prototype filter.
146
- modulated_filters = 2 * prototype_filter * modulation
147
-
148
- return modulated_filters
149
-
150
-
151
- def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
152
- """
153
- Design a lowpass filter using the Kaiser window.
154
-
155
- Parameters
156
- ----------
157
- angular_cutoff : float
158
- The angular frequency cutoff of the filter.
159
- attenuation : float
160
- The desired stopband attenuation in decibels (dB).
161
- filter_length : int, optional
162
- Desired length of the filter. If not provided, it's computed based on the given specs.
163
-
164
- Returns
165
- -------
166
- ndarray
167
- The designed lowpass filter coefficients.
168
- """
169
-
170
- estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
171
-
172
- # Ensure the estimated length is odd.
173
- estimated_length = 2 * (estimated_length // 2) + 1
174
-
175
- if filter_length is None:
176
- filter_length = estimated_length
177
-
178
- return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
179
-
180
-
181
- def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
182
- """
183
- Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
184
-
185
- Parameters
186
- ----------
187
- angular_cutoff : float
188
- Angular frequency cutoff of the filter.
189
- attenuation : float
190
- Desired stopband attenuation in dB.
191
- num_bands : int
192
- Number of bands for the multiband filter system.
193
- filter_length : int, optional
194
- Desired length of the filter.
195
-
196
- Returns
197
- -------
198
- float
199
- The computed objective (loss) value for the given filter specs.
200
- """
201
-
202
- filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
203
- convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
204
-
205
- return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
206
-
207
-
208
- def design_prototype_filter(attenuation, num_bands, filter_length=None):
209
- """
210
- Design the optimal prototype filter for a multiband system given the desired specs.
211
-
212
- Parameters
213
- ----------
214
- attenuation : float
215
- The desired stopband attenuation in dB.
216
- num_bands : int
217
- Number of bands for the multiband filter system.
218
- filter_length : int, optional
219
- Desired length of the filter. If not provided, it's computed based on the given specs.
220
-
221
- Returns
222
- -------
223
- ndarray
224
- The optimal prototype filter coefficients.
225
- """
226
-
227
- optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
228
- 1 / num_bands, disp=0)[0]
229
-
230
- prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
231
- return torch.tensor(prototype_filter, dtype=torch.float32)
232
-
233
- def pad_to_nearest_power_of_two(x):
234
- """
235
- Pads the input tensor 'x' on both sides such that its last dimension
236
- becomes the nearest larger power of two.
237
-
238
- Parameters:
239
- -----------
240
- x : torch.Tensor
241
- The input tensor to be padded.
242
-
243
- Returns:
244
- --------
245
- torch.Tensor
246
- The padded tensor.
247
- """
248
- current_length = x.shape[-1]
249
- target_length = 2**math.ceil(math.log2(current_length))
250
-
251
- total_padding = target_length - current_length
252
- left_padding = total_padding // 2
253
- right_padding = total_padding - left_padding
254
-
255
- return nn.functional.pad(x, (left_padding, right_padding))
256
-
257
- def apply_alias_cancellation(x):
258
- """
259
- Applies alias cancellation by inverting the sign of every
260
- second element of every second row, starting from the second
261
- row's first element in a tensor.
262
-
263
- This operation helps ensure that the aliasing introduced in
264
- each band during the decomposition will be counteracted during
265
- the reconstruction.
266
-
267
- Parameters:
268
- -----------
269
- x : torch.Tensor
270
- The input tensor.
271
-
272
- Returns:
273
- --------
274
- torch.Tensor
275
- Tensor with specific elements' sign inverted for alias cancellation.
276
- """
277
-
278
- # Create a mask of the same shape as 'x', initialized with all ones
279
- mask = torch.ones_like(x)
280
-
281
- # Update specific elements in the mask to -1 to perform inversion
282
- mask[..., 1::2, ::2] = -1
283
-
284
- # Apply the mask to the input tensor 'x'
285
- return x * mask
286
-
287
- def ensure_odd_length(tensor):
288
- """
289
- Pads the last dimension of a tensor to ensure its size is odd.
290
-
291
- Parameters:
292
- -----------
293
- tensor : torch.Tensor
294
- Input tensor whose last dimension might need padding.
295
-
296
- Returns:
297
- --------
298
- torch.Tensor
299
- The original tensor if its last dimension was already odd,
300
- or the padded tensor with an odd-sized last dimension.
301
- """
302
-
303
- last_dim_size = tensor.shape[-1]
304
-
305
- if last_dim_size % 2 == 0:
306
- tensor = nn.functional.pad(tensor, (0, 1))
307
-
308
- return tensor
309
-
310
- def polyphase_analysis(signal, filter_bank):
311
- """
312
- Applies the polyphase method to efficiently analyze the signal using a filter bank.
313
-
314
- Parameters:
315
- -----------
316
- signal : torch.Tensor
317
- Input signal tensor with shape (Batch x Channels x Length).
318
-
319
- filter_bank : torch.Tensor
320
- Filter bank tensor with shape (Bands x Length).
321
-
322
- Returns:
323
- --------
324
- torch.Tensor
325
- Signal split into sub-bands. (Batch x Channels x Bands x Length)
326
- """
327
-
328
- num_bands = filter_bank.shape[0]
329
- num_channels = signal.shape[1]
330
-
331
- # Rearrange signal for polyphase processing.
332
- # Also combine Batch x Channel into one dimension for now.
333
- #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
334
- signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
335
-
336
- # Rearrange the filter bank for matching signal shape
337
- filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
338
-
339
- # Apply convolution with appropriate padding to maintain spatial dimensions
340
- padding = filter_bank.shape[-1] // 2
341
- filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
342
-
343
- # Truncate the last dimension post-convolution to adjust the output shape
344
- filtered_signal = filtered_signal[..., :-1]
345
- # Rearrange the first dimension back into Batch x Channels
346
- filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
347
-
348
- return filtered_signal
349
-
350
- def polyphase_synthesis(signal, filter_bank):
351
- """
352
- Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
353
-
354
- Parameters
355
- ----------
356
- signal : torch.Tensor
357
- Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
358
-
359
- filter_bank : torch.Tensor
360
- Analysis filter bank (shape: Bands x Length).
361
-
362
364
-
365
- Returns
366
- -------
367
- torch.Tensor
368
- Reconstructed signal (shape: Batch x Channels x Length)
369
- """
370
-
371
- num_bands = filter_bank.shape[0]
372
- num_channels = signal.shape[1]
373
-
374
- # Rearrange the filter bank
375
- filter_bank = filter_bank.flip(-1)
376
- filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
377
-
378
- # Combine Batch x Channels into one dimension for now.
379
- signal = rearrange(signal, "b c n t -> (b c) n t")
380
-
381
- # Apply convolution with appropriate padding
382
- padding_amount = filter_bank.shape[-1] // 2 + 1
383
- reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
384
-
385
- # Scale the result
386
- reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
387
-
388
- # Reorganize the output and truncate
389
- reconstructed_signal = reconstructed_signal.flip(1)
390
- reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
391
- reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
392
-
393
- return reconstructed_signal
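A round-trip shape sketch for the PQMF module above (CPU-only; SciPy is needed for the prototype filter design). The sample count is arbitrary, and the reconstruction is a band-limited approximation carrying the analysis/synthesis filter delay.

```python
import torch

pqmf = PQMF(attenuation=100, num_bands=16)

signal = torch.randn(1, 1, 32000)   # [batch, channels, samples]
bands = pqmf(signal)                # [1, 1, 16, 2000]  (samples / num_bands per band)
recon = pqmf.inverse(bands)         # [1, 1, 32000] up to the filter latency
```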
stable/build/lib/stable_audio_tools/models/pretrained.py DELETED
@@ -1,25 +0,0 @@
1
- import json
2
-
3
- from .factory import create_model_from_config
4
- from .utils import load_ckpt_state_dict
5
-
6
- from huggingface_hub import hf_hub_download
7
-
8
- def get_pretrained_model(name: str):
9
-
10
- model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model')
11
-
12
- with open(model_config_path) as f:
13
- model_config = json.load(f)
14
-
15
- model = create_model_from_config(model_config)
16
-
17
- # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
18
- try:
19
- model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
20
- except Exception as e:
21
- model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
22
-
23
- model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
24
-
25
- return model, model_config
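A usage sketch for the loader above; the repo id is hypothetical and only needs to contain `model_config.json` plus `model.safetensors` (or `model.ckpt`) in the layout this function expects. The `sample_rate` lookup assumes the top-level config key used by the factories in this package.

```python
# "your-org/your-audio-model" is a placeholder Hugging Face repo id.
model, model_config = get_pretrained_model("your-org/your-audio-model")

model = model.eval().requires_grad_(False)
sample_rate = model_config["sample_rate"]   # top-level key read by the model factories
```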
stable/build/lib/stable_audio_tools/models/pretransforms.py DELETED
@@ -1,258 +0,0 @@
1
- import torch
2
- from einops import rearrange
3
- from torch import nn
4
-
5
- class Pretransform(nn.Module):
6
- def __init__(self, enable_grad, io_channels, is_discrete):
7
- super().__init__()
8
-
9
- self.is_discrete = is_discrete
10
- self.io_channels = io_channels
11
- self.encoded_channels = None
12
- self.downsampling_ratio = None
13
-
14
- self.enable_grad = enable_grad
15
-
16
- def encode(self, x):
17
- raise NotImplementedError
18
-
19
- def decode(self, z):
20
- raise NotImplementedError
21
-
22
- def tokenize(self, x):
23
- raise NotImplementedError
24
-
25
- def decode_tokens(self, tokens):
26
- raise NotImplementedError
27
-
28
- class AutoencoderPretransform(Pretransform):
29
- def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
30
- super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
31
- self.model = model
32
- self.model.requires_grad_(False).eval()
33
- self.scale=scale
34
- self.downsampling_ratio = model.downsampling_ratio
35
- self.io_channels = model.io_channels
36
- self.sample_rate = model.sample_rate
37
-
38
- self.model_half = model_half
39
- self.iterate_batch = iterate_batch
40
-
41
- self.encoded_channels = model.latent_dim
42
-
43
- self.chunked = chunked
44
- self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
45
- self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
46
-
47
- if self.model_half:
48
- self.model.half()
49
-
50
- def encode(self, x, **kwargs):
51
-
52
- if self.model_half:
53
- x = x.half()
54
- self.model.to(torch.float16)
55
-
56
- encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
57
-
58
- if self.model_half:
59
- encoded = encoded.float()
60
-
61
- return encoded / self.scale
62
-
63
- def decode(self, z, **kwargs):
64
- z = z * self.scale
65
-
66
- if self.model_half:
67
- z = z.half()
68
- self.model.to(torch.float16)
69
-
70
- decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
71
-
72
- if self.model_half:
73
- decoded = decoded.float()
74
-
75
- return decoded
76
-
77
- def tokenize(self, x, **kwargs):
78
- assert self.model.is_discrete, "Cannot tokenize with a continuous model"
79
-
80
- _, info = self.model.encode(x, return_info = True, **kwargs)
81
-
82
- return info[self.model.bottleneck.tokens_id]
83
-
84
- def decode_tokens(self, tokens, **kwargs):
85
- assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
86
-
87
- return self.model.decode_tokens(tokens, **kwargs)
88
-
89
- def load_state_dict(self, state_dict, strict=True):
90
- self.model.load_state_dict(state_dict, strict=strict)
91
-
92
- class WaveletPretransform(Pretransform):
93
- def __init__(self, channels, levels, wavelet):
94
- super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
95
-
96
- from .wavelets import WaveletEncode1d, WaveletDecode1d
97
-
98
- self.encoder = WaveletEncode1d(channels, levels, wavelet)
99
- self.decoder = WaveletDecode1d(channels, levels, wavelet)
100
-
101
- self.downsampling_ratio = 2 ** levels
102
- self.io_channels = channels
103
- self.encoded_channels = channels * self.downsampling_ratio
104
-
105
- def encode(self, x):
106
- return self.encoder(x)
107
-
108
- def decode(self, z):
109
- return self.decoder(z)
110
-
111
- class PQMFPretransform(Pretransform):
112
- def __init__(self, attenuation=100, num_bands=16):
113
- # TODO: Fix PQMF to take in in-channels
114
- super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
115
- from .pqmf import PQMF
116
- self.pqmf = PQMF(attenuation, num_bands)
117
-
118
-
119
- def encode(self, x):
120
- # x is (Batch x Channels x Time)
121
- x = self.pqmf.forward(x)
122
- # pqmf.forward returns (Batch x Channels x Bands x Time)
123
- # but Pretransform needs Batch x Channels x Time
124
- # so concatenate channels and bands into one axis
125
- return rearrange(x, "b c n t -> b (c n) t")
126
-
127
- def decode(self, x):
128
- # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
129
- x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
130
- # returns (Batch x Channels x Time)
131
- return self.pqmf.inverse(x)
132
-
133
- class PretrainedDACPretransform(Pretransform):
134
- def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
135
- super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
136
-
137
- import dac
138
-
139
- model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
140
-
141
- self.model = dac.DAC.load(model_path)
142
-
143
- self.quantize_on_decode = quantize_on_decode
144
-
145
- if model_type == "44khz":
146
- self.downsampling_ratio = 512
147
- else:
148
- self.downsampling_ratio = 320
149
-
150
- self.io_channels = 1
151
-
152
- self.scale = scale
153
-
154
- self.chunked = chunked
155
-
156
- self.encoded_channels = self.model.latent_dim
157
-
158
- self.num_quantizers = self.model.n_codebooks
159
-
160
- self.codebook_size = self.model.codebook_size
161
-
162
- def encode(self, x):
163
-
164
- latents = self.model.encoder(x)
165
-
166
- if self.quantize_on_decode:
167
- output = latents
168
- else:
169
- z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
170
- output = z
171
-
172
- if self.scale != 1.0:
173
- output = output / self.scale
174
-
175
- return output
176
-
177
- def decode(self, z):
178
-
179
- if self.scale != 1.0:
180
- z = z * self.scale
181
-
182
- if self.quantize_on_decode:
183
- z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
184
-
185
- return self.model.decode(z)
186
-
187
- def tokenize(self, x):
188
- return self.model.encode(x)[1]
189
-
190
- def decode_tokens(self, tokens):
191
- latents = self.model.quantizer.from_codes(tokens)
192
- return self.model.decode(latents)
193
-
194
- class AudiocraftCompressionPretransform(Pretransform):
195
- def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
196
- super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
197
-
198
- try:
199
- from audiocraft.models import CompressionModel
200
- except ImportError:
201
- raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
202
-
203
- self.model = CompressionModel.get_pretrained(model_type)
204
-
205
- self.quantize_on_decode = quantize_on_decode
206
-
207
- self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
208
-
209
- self.sample_rate = self.model.sample_rate
210
-
211
- self.io_channels = self.model.channels
212
-
213
- self.scale = scale
214
-
215
- #self.encoded_channels = self.model.latent_dim
216
-
217
- self.num_quantizers = self.model.num_codebooks
218
-
219
- self.codebook_size = self.model.cardinality
220
-
221
- self.model.to(torch.float16).eval().requires_grad_(False)
222
-
223
- def encode(self, x):
224
-
225
- assert False, "Audiocraft compression models do not support continuous encoding"
226
-
227
- # latents = self.model.encoder(x)
228
-
229
- # if self.quantize_on_decode:
230
- # output = latents
231
- # else:
232
- # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
233
- # output = z
234
-
235
- # if self.scale != 1.0:
236
- # output = output / self.scale
237
-
238
- # return output
239
-
240
- def decode(self, z):
241
-
242
- assert False, "Audiocraft compression models do not support continuous decoding"
243
-
244
- # if self.scale != 1.0:
245
- # z = z * self.scale
246
-
247
- # if self.quantize_on_decode:
248
- # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
249
-
250
- # return self.model.decode(z)
251
-
252
- def tokenize(self, x):
253
- with torch.cuda.amp.autocast(enabled=False):
254
- return self.model.encode(x.to(torch.float16))[0]
255
-
256
- def decode_tokens(self, tokens):
257
- with torch.cuda.amp.autocast(enabled=False):
258
- return self.model.decode(tokens)
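A shape sketch for the PQMF pretransform above: the sub-bands are folded into the channel axis, so a mono input comes back with `num_bands` channels at 1/`num_bands` the length. The sizes are illustrative.

```python
import torch

pretransform = PQMFPretransform(attenuation=100, num_bands=16)

audio = torch.randn(1, 1, 32000)        # [batch, io_channels=1, samples]
latents = pretransform.encode(audio)    # [1, 16, 2000]
recon = pretransform.decode(latents)    # [1, 1, 32000] (approximate reconstruction)
```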
stable/build/lib/stable_audio_tools/models/transformer.py DELETED
@@ -1,805 +0,0 @@
1
- from functools import reduce, partial
2
- from packaging import version
3
-
4
- from einops import rearrange, repeat
5
- from einops.layers.torch import Rearrange
6
- import torch
7
- import torch.nn.functional as F
8
- from torch import nn, einsum
9
- from torch.cuda.amp import autocast
10
- from typing import Callable, Literal
11
-
12
- try:
13
- from flash_attn import flash_attn_func, flash_attn_kvpacked_func
14
- except ImportError as e:
15
- print(e)
16
- print('flash_attn not installed, disabling Flash Attention')
17
- flash_attn_kvpacked_func = None
18
- flash_attn_func = None
19
-
20
- try:
21
- import natten
22
- except ImportError:
23
- natten = None
24
-
25
- def checkpoint(function, *args, **kwargs):
26
- kwargs.setdefault("use_reentrant", False)
27
- return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
28
-
29
-
30
- # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
31
- # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
32
-
33
- def create_causal_mask(i, j, device):
34
- return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
35
-
36
- def or_reduce(masks):
37
- head, *body = masks
38
- for rest in body:
39
- head = head | rest
40
- return head
41
-
42
- # positional embeddings
43
-
44
- class AbsolutePositionalEmbedding(nn.Module):
45
- def __init__(self, dim, max_seq_len):
46
- super().__init__()
47
- self.scale = dim ** -0.5
48
- self.max_seq_len = max_seq_len
49
- self.emb = nn.Embedding(max_seq_len, dim)
50
-
51
- def forward(self, x, pos = None, seq_start_pos = None):
52
- seq_len, device = x.shape[1], x.device
53
- assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
54
-
55
- if pos is None:
56
- pos = torch.arange(seq_len, device = device)
57
-
58
- if seq_start_pos is not None:
59
- pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
60
-
61
- pos_emb = self.emb(pos)
62
- pos_emb = pos_emb * self.scale
63
- return pos_emb
64
-
65
- class ScaledSinusoidalEmbedding(nn.Module):
66
- def __init__(self, dim, theta = 10000):
67
- super().__init__()
68
- assert (dim % 2) == 0, 'dimension must be divisible by 2'
69
- self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
70
-
71
- half_dim = dim // 2
72
- freq_seq = torch.arange(half_dim).float() / half_dim
73
- inv_freq = theta ** -freq_seq
74
- self.register_buffer('inv_freq', inv_freq, persistent = False)
75
-
76
- def forward(self, x, pos = None, seq_start_pos = None):
77
- seq_len, device = x.shape[1], x.device
78
-
79
- if pos is None:
80
- pos = torch.arange(seq_len, device = device)
81
-
82
- if seq_start_pos is not None:
83
- pos = pos - seq_start_pos[..., None]
84
-
85
- emb = einsum('i, j -> i j', pos, self.inv_freq)
86
- emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
87
- return emb * self.scale
88
-
89
- class RotaryEmbedding(nn.Module):
90
- def __init__(
91
- self,
92
- dim,
93
- use_xpos = False,
94
- scale_base = 512,
95
- interpolation_factor = 1.,
96
- base = 10000,
97
- base_rescale_factor = 1.
98
- ):
99
- super().__init__()
100
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
101
- # has some connection to NTK literature
102
- # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
103
- base *= base_rescale_factor ** (dim / (dim - 2))
104
-
105
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
106
- self.register_buffer('inv_freq', inv_freq)
107
-
108
- assert interpolation_factor >= 1.
109
- self.interpolation_factor = interpolation_factor
110
-
111
- if not use_xpos:
112
- self.register_buffer('scale', None)
113
- return
114
-
115
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
116
-
117
- self.scale_base = scale_base
118
- self.register_buffer('scale', scale)
119
-
120
- def forward_from_seq_len(self, seq_len):
121
- device = self.inv_freq.device
122
-
123
- t = torch.arange(seq_len, device = device)
124
- return self.forward(t)
125
-
126
- @autocast(enabled = False)
127
- def forward(self, t):
128
- device = self.inv_freq.device
129
-
130
- t = t.to(torch.float32)
131
-
132
- t = t / self.interpolation_factor
133
-
134
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
135
- freqs = torch.cat((freqs, freqs), dim = -1)
136
-
137
- if self.scale is None:
138
- return freqs, 1.
139
-
140
- power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
141
- scale = self.scale ** rearrange(power, 'n -> n 1')
142
- scale = torch.cat((scale, scale), dim = -1)
143
-
144
- return freqs, scale
145
-
146
- def rotate_half(x):
147
- x = rearrange(x, '... (j d) -> ... j d', j = 2)
148
- x1, x2 = x.unbind(dim = -2)
149
- return torch.cat((-x2, x1), dim = -1)
150
-
151
- @autocast(enabled = False)
152
- def apply_rotary_pos_emb(t, freqs, scale = 1):
153
- out_dtype = t.dtype
154
-
155
- # cast to float32 if necessary for numerical stability
156
- dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
157
- rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
158
- freqs, t = freqs.to(dtype), t.to(dtype)
159
- freqs = freqs[-seq_len:, :]
160
-
161
- if t.ndim == 4 and freqs.ndim == 3:
162
- freqs = rearrange(freqs, 'b n d -> b 1 n d')
163
-
164
- # partial rotary embeddings, Wang et al. GPT-J
165
- t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
166
- t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
167
-
168
- t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
169
-
170
- return torch.cat((t, t_unrotated), dim = -1)
171
-
172
- # norms
173
- class LayerNorm(nn.Module):
174
- def __init__(self, dim, bias=False, fix_scale=False):
175
- """
176
- bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
177
- """
178
- super().__init__()
179
-
180
- if fix_scale:
181
- self.register_buffer("gamma", torch.ones(dim))
182
- else:
183
- self.gamma = nn.Parameter(torch.ones(dim))
184
-
185
- if bias:
186
- self.beta = nn.Parameter(torch.zeros(dim))
187
- else:
188
- self.register_buffer("beta", torch.zeros(dim))
189
-
190
-
191
- def forward(self, x):
192
- return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
193
-
194
- # feedforward
195
-
196
- class GLU(nn.Module):
197
- def __init__(
198
- self,
199
- dim_in,
200
- dim_out,
201
- activation: Callable,
202
- use_conv = False,
203
- conv_kernel_size = 3,
204
- ):
205
- super().__init__()
206
- self.act = activation
207
- self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
208
- self.use_conv = use_conv
209
-
210
- def forward(self, x):
211
- if self.use_conv:
212
- x = rearrange(x, 'b n d -> b d n')
213
- x = self.proj(x)
214
- x = rearrange(x, 'b d n -> b n d')
215
- else:
216
- x = self.proj(x)
217
-
218
- x, gate = x.chunk(2, dim = -1)
219
- return x * self.act(gate)
220
-
221
- class FeedForward(nn.Module):
222
- def __init__(
223
- self,
224
- dim,
225
- dim_out = None,
226
- mult = 4,
227
- no_bias = False,
228
- glu = True,
229
- use_conv = False,
230
- conv_kernel_size = 3,
231
- zero_init_output = True,
232
- ):
233
- super().__init__()
234
- inner_dim = int(dim * mult)
235
-
236
- # Default to SwiGLU
237
-
238
- activation = nn.SiLU()
239
-
240
- dim_out = dim if dim_out is None else dim_out
241
-
242
- if glu:
243
- linear_in = GLU(dim, inner_dim, activation)
244
- else:
245
- linear_in = nn.Sequential(
246
- Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
247
- nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
248
- Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
249
- activation
250
- )
251
-
252
- linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
253
-
254
- # init last linear layer to 0
255
- if zero_init_output:
256
- nn.init.zeros_(linear_out.weight)
257
- if not no_bias:
258
- nn.init.zeros_(linear_out.bias)
259
-
260
-
261
- self.ff = nn.Sequential(
262
- linear_in,
263
- Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
264
- linear_out,
265
- Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
266
- )
267
-
268
- def forward(self, x):
269
- return self.ff(x)
270
-
271
- class Attention(nn.Module):
272
- def __init__(
273
- self,
274
- dim,
275
- dim_heads = 64,
276
- dim_context = None,
277
- causal = False,
278
- zero_init_output=True,
279
- qk_norm = False,
280
- natten_kernel_size = None
281
- ):
282
- super().__init__()
283
- self.dim = dim
284
- self.dim_heads = dim_heads
285
- self.causal = causal
286
-
287
- dim_kv = dim_context if dim_context is not None else dim
288
-
289
- self.num_heads = dim // dim_heads
290
- self.kv_heads = dim_kv // dim_heads
291
-
292
- if dim_context is not None:
293
- self.to_q = nn.Linear(dim, dim, bias=False)
294
- self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
295
- else:
296
- self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
297
-
298
- self.to_out = nn.Linear(dim, dim, bias=False)
299
-
300
- if zero_init_output:
301
- nn.init.zeros_(self.to_out.weight)
302
-
303
- self.qk_norm = qk_norm
304
-
305
- # Using 1d neighborhood attention
306
- self.natten_kernel_size = natten_kernel_size
307
- if natten_kernel_size is not None:
308
- return
309
-
310
- self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
311
-
312
- self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
313
-
314
- self.sdp_kwargs = dict(
315
- enable_flash = True,
316
- enable_math = True,
317
- enable_mem_efficient = True
318
- )
319
-
320
- def flash_attn(
321
- self,
322
- q,
323
- k,
324
- v,
325
- mask = None,
326
- causal = None
327
- ):
328
- batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
329
- kv_heads = k.shape[1]
330
- # Recommended for multi-query single-key-value attention by Tri Dao
331
- # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
332
-
333
- if heads != kv_heads:
334
- # Repeat interleave kv_heads to match q_heads
335
- heads_per_kv_head = heads // kv_heads
336
- k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
337
-
338
- if k.ndim == 3:
339
- k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
340
-
341
- if v.ndim == 3:
342
- v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
343
-
344
- causal = self.causal if causal is None else causal
345
-
346
- if q_len == 1 and causal:
347
- causal = False
348
-
349
- if mask is not None:
350
- assert mask.ndim == 4
351
- mask = mask.expand(batch, heads, q_len, k_len)
352
-
353
- # handle kv cache - this should be bypassable in updated flash attention 2
354
-
355
- if k_len > q_len and causal:
356
- causal_mask = self.create_causal_mask(q_len, k_len, device = device)
357
- if mask is None:
358
- mask = ~causal_mask
359
- else:
360
- mask = mask & ~causal_mask
361
- causal = False
362
-
363
- # manually handle causal mask, if another mask was given
364
-
365
- row_is_entirely_masked = None
366
-
367
- if mask is not None and causal:
368
- causal_mask = self.create_causal_mask(q_len, k_len, device = device)
369
- mask = mask & ~causal_mask
370
-
371
- # protect against an entire row being masked out
372
-
373
- row_is_entirely_masked = ~mask.any(dim = -1)
374
- mask[..., 0] = mask[..., 0] | row_is_entirely_masked
375
-
376
- causal = False
377
-
378
- with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
379
- out = F.scaled_dot_product_attention(
380
- q, k, v,
381
- attn_mask = mask,
382
- is_causal = causal
383
- )
384
-
385
- # for a row that is entirely masked out, should zero out the output of that row token
386
-
387
- if row_is_entirely_masked is not None:
388
- out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
389
-
390
- return out
391
-
392
- def forward(
393
- self,
394
- x,
395
- context = None,
396
- mask = None,
397
- context_mask = None,
398
- rotary_pos_emb = None,
399
- causal = None
400
- ):
401
- h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
402
-
403
- kv_input = context if has_context else x
404
-
405
- if hasattr(self, 'to_q'):
406
- # Use separate linear projections for q and k/v
407
- q = self.to_q(x)
408
- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
409
-
410
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
411
-
412
- k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
413
- else:
414
- # Use fused linear projection
415
- q, k, v = self.to_qkv(x).chunk(3, dim=-1)
416
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
417
-
418
- # Normalize q and k for cosine sim attention
419
- if self.qk_norm:
420
- q = F.normalize(q, dim=-1)
421
- k = F.normalize(k, dim=-1)
422
-
423
- if rotary_pos_emb is not None and not has_context:
424
- freqs, _ = rotary_pos_emb
425
-
426
- q_dtype = q.dtype
427
- k_dtype = k.dtype
428
-
429
- q = q.to(torch.float32)
430
- k = k.to(torch.float32)
431
- freqs = freqs.to(torch.float32)
432
-
433
- q = apply_rotary_pos_emb(q, freqs)
434
- k = apply_rotary_pos_emb(k, freqs)
435
-
436
- q = q.to(q_dtype)
437
- k = k.to(k_dtype)
438
-
439
- input_mask = context_mask
440
-
441
- if input_mask is None and not has_context:
442
- input_mask = mask
443
-
444
- # determine masking
445
- masks = []
446
- final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
447
-
448
- if input_mask is not None:
449
- input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
450
- masks.append(~input_mask)
451
-
452
- # Other masks will be added here later
453
-
454
- if len(masks) > 0:
455
- final_attn_mask = ~or_reduce(masks)
456
-
457
- n, device = q.shape[-2], q.device
458
-
459
- causal = self.causal if causal is None else causal
460
-
461
- if n == 1 and causal:
462
- causal = False
463
-
464
- if self.natten_kernel_size is not None:
465
- if natten is None:
466
- raise ImportError('natten not installed, please install natten to use neighborhood attention')
467
-
468
- dtype_in = q.dtype
469
- q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
470
-
471
- attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
472
-
473
- if final_attn_mask is not None:
474
- attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
475
-
476
- attn = F.softmax(attn, dim=-1, dtype=torch.float32)
477
-
478
- out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
479
-
480
- # Prioritize Flash Attention 2
481
- elif self.use_fa_flash:
482
- assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
483
- # Flash Attention 2 requires FP16 inputs
484
- fa_dtype_in = q.dtype
485
- q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
486
-
487
- out = flash_attn_func(q, k, v, causal = causal)
488
-
489
- out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
490
-
491
- # Fall back to PyTorch implementation
492
- elif self.use_pt_flash:
493
- out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
494
-
495
- else:
496
- # Fall back to custom implementation
497
-
498
- if h != kv_h:
499
- # Repeat interleave kv_heads to match q_heads
500
- heads_per_kv_head = h // kv_h
501
- k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
502
-
503
- scale = 1. / (q.shape[-1] ** 0.5)
504
-
505
- kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
506
-
507
- dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
508
-
509
- i, j, dtype = *dots.shape[-2:], dots.dtype
510
-
511
- mask_value = -torch.finfo(dots.dtype).max
512
-
513
- if final_attn_mask is not None:
514
- dots = dots.masked_fill(~final_attn_mask, mask_value)
515
-
516
- if causal:
517
- causal_mask = self.create_causal_mask(i, j, device = device)
518
- dots = dots.masked_fill(causal_mask, mask_value)
519
-
520
- attn = F.softmax(dots, dim=-1, dtype=torch.float32)
521
- attn = attn.type(dtype)
522
-
523
- out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
524
-
525
- # merge heads
526
- out = rearrange(out, ' b h n d -> b n (h d)')
527
-
528
- # Communicate between heads
529
-
530
- # with autocast(enabled = False):
531
- # out_dtype = out.dtype
532
- # out = out.to(torch.float32)
533
- # out = self.to_out(out).to(out_dtype)
534
- out = self.to_out(out)
535
-
536
- if mask is not None:
537
- mask = rearrange(mask, 'b n -> b n 1')
538
- out = out.masked_fill(~mask, 0.)
539
-
540
- return out
541
-
542
- class ConformerModule(nn.Module):
543
- def __init__(
544
- self,
545
- dim,
546
- norm_kwargs = {},
547
- ):
548
-
549
- super().__init__()
550
-
551
- self.dim = dim
552
-
553
- self.in_norm = LayerNorm(dim, **norm_kwargs)
554
- self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
555
- self.glu = GLU(dim, dim, nn.SiLU())
556
- self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
557
- self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
558
- self.swish = nn.SiLU()
559
- self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
560
-
561
- def forward(self, x):
562
- x = self.in_norm(x)
563
- x = rearrange(x, 'b n d -> b d n')
564
- x = self.pointwise_conv(x)
565
- x = rearrange(x, 'b d n -> b n d')
566
- x = self.glu(x)
567
- x = rearrange(x, 'b n d -> b d n')
568
- x = self.depthwise_conv(x)
569
- x = rearrange(x, 'b d n -> b n d')
570
- x = self.mid_norm(x)
571
- x = self.swish(x)
572
- x = rearrange(x, 'b n d -> b d n')
573
- x = self.pointwise_conv_2(x)
574
- x = rearrange(x, 'b d n -> b n d')
575
-
576
- return x
577
-
578
- class TransformerBlock(nn.Module):
579
- def __init__(
580
- self,
581
- dim,
582
- dim_heads = 64,
583
- cross_attend = False,
584
- dim_context = None,
585
- global_cond_dim = None,
586
- causal = False,
587
- zero_init_branch_outputs = True,
588
- conformer = False,
589
- layer_ix = -1,
590
- remove_norms = False,
591
- attn_kwargs = {},
592
- ff_kwargs = {},
593
- norm_kwargs = {}
594
- ):
595
-
596
- super().__init__()
597
- self.dim = dim
598
- self.dim_heads = dim_heads
599
- self.cross_attend = cross_attend
600
- self.dim_context = dim_context
601
- self.causal = causal
602
-
603
- self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
604
-
605
- self.self_attn = Attention(
606
- dim,
607
- dim_heads = dim_heads,
608
- causal = causal,
609
- zero_init_output=zero_init_branch_outputs,
610
- **attn_kwargs
611
- )
612
-
613
- if cross_attend:
614
- self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
615
- self.cross_attn = Attention(
616
- dim,
617
- dim_heads = dim_heads,
618
- dim_context=dim_context,
619
- causal = causal,
620
- zero_init_output=zero_init_branch_outputs,
621
- **attn_kwargs
622
- )
623
-
624
- self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
625
- self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
626
-
627
- self.layer_ix = layer_ix
628
-
629
- self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
630
-
631
- self.global_cond_dim = global_cond_dim
632
-
633
- if global_cond_dim is not None:
634
- self.to_scale_shift_gate = nn.Sequential(
635
- nn.SiLU(),
636
- nn.Linear(global_cond_dim, dim * 6, bias=False)
637
- )
638
-
639
- nn.init.zeros_(self.to_scale_shift_gate[1].weight)
640
- #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
641
-
642
- def forward(
643
- self,
644
- x,
645
- context = None,
646
- global_cond=None,
647
- mask = None,
648
- context_mask = None,
649
- rotary_pos_emb = None
650
- ):
651
- if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
652
-
653
- scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
654
-
655
- # self-attention with adaLN
656
- residual = x
657
- x = self.pre_norm(x)
658
- x = x * (1 + scale_self) + shift_self
659
- x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
660
- x = x * torch.sigmoid(1 - gate_self)
661
- x = x + residual
662
-
663
- if context is not None:
664
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
665
-
666
- if self.conformer is not None:
667
- x = x + self.conformer(x)
668
-
669
- # feedforward with adaLN
670
- residual = x
671
- x = self.ff_norm(x)
672
- x = x * (1 + scale_ff) + shift_ff
673
- x = self.ff(x)
674
- x = x * torch.sigmoid(1 - gate_ff)
675
- x = x + residual
676
-
677
- else:
678
- x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
679
-
680
- if context is not None:
681
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
682
-
683
- if self.conformer is not None:
684
- x = x + self.conformer(x)
685
-
686
- x = x + self.ff(self.ff_norm(x))
687
-
688
- return x
689
-
690
- class ContinuousTransformer(nn.Module):
691
- def __init__(
692
- self,
693
- dim,
694
- depth,
695
- *,
696
- dim_in = None,
697
- dim_out = None,
698
- dim_heads = 64,
699
- cross_attend=False,
700
- cond_token_dim=None,
701
- global_cond_dim=None,
702
- causal=False,
703
- rotary_pos_emb=True,
704
- zero_init_branch_outputs=True,
705
- conformer=False,
706
- use_sinusoidal_emb=False,
707
- use_abs_pos_emb=False,
708
- abs_pos_emb_max_length=10000,
709
- **kwargs
710
- ):
711
-
712
- super().__init__()
713
-
714
- self.dim = dim
715
- self.depth = depth
716
- self.causal = causal
717
- self.layers = nn.ModuleList([])
718
-
719
- self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
720
- self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
721
-
722
- if rotary_pos_emb:
723
- self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
724
- else:
725
- self.rotary_pos_emb = None
726
-
727
- self.use_sinusoidal_emb = use_sinusoidal_emb
728
- if use_sinusoidal_emb:
729
- self.pos_emb = ScaledSinusoidalEmbedding(dim)
730
-
731
- self.use_abs_pos_emb = use_abs_pos_emb
732
- if use_abs_pos_emb:
733
- self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
734
-
735
- for i in range(depth):
736
- self.layers.append(
737
- TransformerBlock(
738
- dim,
739
- dim_heads = dim_heads,
740
- cross_attend = cross_attend,
741
- dim_context = cond_token_dim,
742
- global_cond_dim = global_cond_dim,
743
- causal = causal,
744
- zero_init_branch_outputs = zero_init_branch_outputs,
745
- conformer=conformer,
746
- layer_ix=i,
747
- **kwargs
748
- )
749
- )
750
-
751
- def forward(
752
- self,
753
- x,
754
- mask = None,
755
- prepend_embeds = None,
756
- prepend_mask = None,
757
- global_cond = None,
758
- return_info = False,
759
- **kwargs
760
- ):
761
- batch, seq, device = *x.shape[:2], x.device
762
-
763
- info = {
764
- "hidden_states": [],
765
- }
766
-
767
- x = self.project_in(x)
768
-
769
- if prepend_embeds is not None:
770
- prepend_length, prepend_dim = prepend_embeds.shape[1:]
771
-
772
- assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
773
-
774
- x = torch.cat((prepend_embeds, x), dim = -2)
775
-
776
- if prepend_mask is not None or mask is not None:
777
- mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
778
- prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
779
-
780
- mask = torch.cat((prepend_mask, mask), dim = -1)
781
-
782
- # Attention layers
783
-
784
- if self.rotary_pos_emb is not None:
785
- rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
786
- else:
787
- rotary_pos_emb = None
788
-
789
- if self.use_sinusoidal_emb or self.use_abs_pos_emb:
790
- x = x + self.pos_emb(x)
791
-
792
- # Iterate over the transformer layers
793
- for layer in self.layers:
794
- #x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
795
- x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
796
-
797
- if return_info:
798
- info["hidden_states"].append(x)
799
-
800
- x = self.project_out(x)
801
-
802
- if return_info:
803
- return x, info
804
-
805
- return x
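
For orientation, a minimal usage sketch of the `ContinuousTransformer` defined in the deleted file above. The import path, model sizes, and tensor shapes are assumptions for illustration only (they assume the `stable_audio_tools` package from this repo is installed); they are not taken from the commit itself.

```python
import torch
# Assumed import path, based on the package layout shown in this diff
from stable_audio_tools.models.transformer import ContinuousTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative sizes: 64-dim latents projected into a 512-wide transformer
# with cross-attention over 256-dim conditioning tokens.
model = ContinuousTransformer(
    dim=512,
    depth=2,
    dim_in=64,
    dim_out=64,
    cross_attend=True,
    cond_token_dim=256,
).to(device)

x = torch.randn(2, 128, 64, device=device)        # (batch, sequence, dim_in)
context = torch.randn(2, 77, 256, device=device)  # conditioning tokens, e.g. from a text encoder
out = model(x, context=context)                   # projected back out to (2, 128, 64)
print(out.shape)
```

The module picks its attention backend automatically (Flash Attention 2, PyTorch SDPA, or the manual fallback), so the same sketch runs on CPU or GPU without changes.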
 
stable/build/lib/stable_audio_tools/models/utils.py DELETED
@@ -1,89 +0,0 @@
1
- import torch
2
- from safetensors.torch import load_file
3
-
4
- from torch.nn.utils import remove_weight_norm
5
-
6
- def load_ckpt_state_dict(ckpt_path):
7
- if ckpt_path.endswith(".safetensors"):
8
- state_dict = load_file(ckpt_path)
9
- else:
10
- state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
11
-
12
- return state_dict
13
-
14
- def remove_weight_norm_from_model(model):
15
- for module in model.modules():
16
- if hasattr(module, "weight"):
17
- print(f"Removing weight norm from {module}")
18
- remove_weight_norm(module)
19
-
20
- return model
21
-
22
- # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
23
- # License can be found in LICENSES/LICENSE_META.txt
24
-
25
- def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
26
- """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
27
-
28
- Args:
29
- input (torch.Tensor): The input tensor containing probabilities.
30
- num_samples (int): Number of samples to draw.
31
- replacement (bool): Whether to draw with replacement or not.
32
- Keywords args:
33
- generator (torch.Generator): A pseudorandom number generator for sampling.
34
- Returns:
35
- torch.Tensor: Last dimension contains num_samples indices
36
- sampled from the multinomial probability distribution
37
- located in the last dimension of tensor input.
38
- """
39
-
40
- if num_samples == 1:
41
- q = torch.empty_like(input).exponential_(1, generator=generator)
42
- return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
43
-
44
- input_ = input.reshape(-1, input.shape[-1])
45
- output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
46
- output = output_.reshape(*list(input.shape[:-1]), -1)
47
- return output
48
-
49
-
50
- def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
51
- """Sample next token from top K values along the last dimension of the input probs tensor.
52
-
53
- Args:
54
- probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
55
- k (int): The k in “top-k”.
56
- Returns:
57
- torch.Tensor: Sampled tokens.
58
- """
59
- top_k_value, _ = torch.topk(probs, k, dim=-1)
60
- min_value_top_k = top_k_value[..., [-1]]
61
- probs *= (probs >= min_value_top_k).float()
62
- probs.div_(probs.sum(dim=-1, keepdim=True))
63
- next_token = multinomial(probs, num_samples=1)
64
- return next_token
65
-
66
-
67
- def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
68
- """Sample next token from top P probabilities along the last dimension of the input probs tensor.
69
-
70
- Args:
71
- probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
72
- p (int): The p in “top-p”.
73
- Returns:
74
- torch.Tensor: Sampled tokens.
75
- """
76
- probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
77
- probs_sum = torch.cumsum(probs_sort, dim=-1)
78
- mask = probs_sum - probs_sort > p
79
- probs_sort *= (~mask).float()
80
- probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
81
- next_token = multinomial(probs_sort, num_samples=1)
82
- next_token = torch.gather(probs_idx, -1, next_token)
83
- return next_token
84
-
85
- def next_power_of_two(n):
86
- return 2 ** (n - 1).bit_length()
87
-
88
- def next_multiple_of_64(n):
89
- return ((n + 63) // 64) * 64
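
As a quick illustration of the sampling helpers above: both expect a probability tensor with token candidates on the last dimension and return index tensors of shape `(..., 1)`. This is an illustrative sketch with an assumed import path, not code from the commit.

```python
import torch
# Assumed import path, based on the package layout shown in this diff
from stable_audio_tools.models.utils import sample_top_k, sample_top_p

logits = torch.randn(2, 4, 1024)              # (batch, codebooks, vocab)
probs = torch.softmax(logits, dim=-1)

# sample_top_k modifies `probs` in place, so pass a copy if you reuse it.
next_k = sample_top_k(probs.clone(), k=250)   # sample among the 250 most likely tokens
next_p = sample_top_p(probs.clone(), p=0.95)  # nucleus (top-p) sampling with p = 0.95
print(next_k.shape, next_p.shape)             # torch.Size([2, 4, 1]) twice
```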
 
stable/build/lib/stable_audio_tools/models/wavelets.py DELETED
@@ -1,82 +0,0 @@
1
- """The 1D discrete wavelet transform for PyTorch."""
2
-
3
- from einops import rearrange
4
- import pywt
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
- from typing import Literal
9
-
10
-
11
- def get_filter_bank(wavelet):
12
- filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
13
- if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0):
14
- filt = filt[:, 1:]
15
- return filt
16
-
17
- class WaveletEncode1d(nn.Module):
18
- def __init__(self,
19
- channels,
20
- levels,
21
- wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
22
- super().__init__()
23
- self.wavelet = wavelet
24
- self.channels = channels
25
- self.levels = levels
26
- filt = get_filter_bank(wavelet)
27
- assert filt.shape[-1] % 2 == 1
28
- kernel = filt[:2, None]
29
- kernel = torch.flip(kernel, dims=(-1,))
30
- index_i = torch.repeat_interleave(torch.arange(2), channels)
31
- index_j = torch.tile(torch.arange(channels), (2,))
32
- kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
33
- kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
34
- self.register_buffer("kernel", kernel_final)
35
-
36
- def forward(self, x):
37
- for i in range(self.levels):
38
- low, rest = x[:, : self.channels], x[:, self.channels :]
39
- pad = self.kernel.shape[-1] // 2
40
- low = F.pad(low, (pad, pad), "reflect")
41
- low = F.conv1d(low, self.kernel, stride=2)
42
- rest = rearrange(
43
- rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
44
- )
45
- x = torch.cat([low, rest], dim=1)
46
- return x
47
-
48
-
49
- class WaveletDecode1d(nn.Module):
50
- def __init__(self,
51
- channels,
52
- levels,
53
- wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"):
54
- super().__init__()
55
- self.wavelet = wavelet
56
- self.channels = channels
57
- self.levels = levels
58
- filt = get_filter_bank(wavelet)
59
- assert filt.shape[-1] % 2 == 1
60
- kernel = filt[2:, None]
61
- index_i = torch.repeat_interleave(torch.arange(2), channels)
62
- index_j = torch.tile(torch.arange(channels), (2,))
63
- kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1])
64
- kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0]
65
- self.register_buffer("kernel", kernel_final)
66
-
67
- def forward(self, x):
68
- for i in range(self.levels):
69
- low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :]
70
- pad = self.kernel.shape[-1] // 2 + 2
71
- low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2)
72
- low = F.pad(low, (pad, pad), "reflect")
73
- low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2)
74
- low = F.conv_transpose1d(
75
- low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2
76
- )
77
- low = low[..., pad - 1 : -pad]
78
- rest = rearrange(
79
- rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
80
- )
81
- x = torch.cat([low, rest], dim=1)
82
- return x
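
A minimal roundtrip sketch for the wavelet transform pair above (illustrative only; the import path is assumed and PyWavelets must be installed). Each encode level halves the sequence length and doubles the channel count, so the input length should be divisible by `2 ** levels`.

```python
import torch
# Assumed import path, based on the package layout shown in this diff (requires `pywt`)
from stable_audio_tools.models.wavelets import WaveletEncode1d, WaveletDecode1d

channels, levels = 2, 3
encode = WaveletEncode1d(channels, levels)   # defaults to the "bior4.4" wavelet
decode = WaveletDecode1d(channels, levels)

x = torch.randn(1, channels, 16384)          # length divisible by 2 ** levels
z = encode(x)                                # (1, channels * 2 ** levels, 16384 // 2 ** levels)
y = decode(z)                                # back to (1, channels, 16384)
print(z.shape, y.shape)
```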
 
stable/build/lib/stable_audio_tools/training/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .factory import create_training_wrapper_from_config, create_demo_callback_from_config
 
 
stable/build/lib/stable_audio_tools/training/autoencoders.py DELETED
@@ -1,477 +0,0 @@
1
- import torch
2
- import torchaudio
3
- import wandb
4
- from einops import rearrange
5
- from safetensors.torch import save_file, save_model
6
- from ema_pytorch import EMA
7
- from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss
8
- import pytorch_lightning as pl
9
- from ..models.autoencoders import AudioAutoencoder
10
- from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
11
- from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
12
- from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
13
- from .utils import create_optimizer_from_config, create_scheduler_from_config
14
-
15
-
16
- from pytorch_lightning.utilities.rank_zero import rank_zero_only
17
- from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
18
-
19
- class AutoencoderTrainingWrapper(pl.LightningModule):
20
- def __init__(
21
- self,
22
- autoencoder: AudioAutoencoder,
23
- lr: float = 1e-4,
24
- warmup_steps: int = 0,
25
- encoder_freeze_on_warmup: bool = False,
26
- sample_rate=48000,
27
- loss_config: dict = None,
28
- optimizer_configs: dict = None,
29
- use_ema: bool = True,
30
- ema_copy = None,
31
- force_input_mono = False,
32
- latent_mask_ratio = 0.0,
33
- teacher_model: AudioAutoencoder = None
34
- ):
35
- super().__init__()
36
-
37
- self.automatic_optimization = False
38
-
39
- self.autoencoder = autoencoder
40
-
41
- self.warmed_up = False
42
- self.warmup_steps = warmup_steps
43
- self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
44
- self.lr = lr
45
-
46
- self.force_input_mono = force_input_mono
47
-
48
- self.teacher_model = teacher_model
49
-
50
- if optimizer_configs is None:
51
- optimizer_configs ={
52
- "autoencoder": {
53
- "optimizer": {
54
- "type": "AdamW",
55
- "config": {
56
- "lr": lr,
57
- "betas": (.8, .99)
58
- }
59
- }
60
- },
61
- "discriminator": {
62
- "optimizer": {
63
- "type": "AdamW",
64
- "config": {
65
- "lr": lr,
66
- "betas": (.8, .99)
67
- }
68
- }
69
- }
70
-
71
- }
72
-
73
- self.optimizer_configs = optimizer_configs
74
-
75
- if loss_config is None:
76
- scales = [2048, 1024, 512, 256, 128, 64, 32]
77
- hop_sizes = []
78
- win_lengths = []
79
- overlap = 0.75
80
- for s in scales:
81
- hop_sizes.append(int(s * (1 - overlap)))
82
- win_lengths.append(s)
83
-
84
- loss_config = {
85
- "discriminator": {
86
- "type": "encodec",
87
- "config": {
88
- "n_ffts": scales,
89
- "hop_lengths": hop_sizes,
90
- "win_lengths": win_lengths,
91
- "filters": 32
92
- },
93
- "weights": {
94
- "adversarial": 0.1,
95
- "feature_matching": 5.0,
96
- }
97
- },
98
- "spectral": {
99
- "type": "mrstft",
100
- "config": {
101
- "fft_sizes": scales,
102
- "hop_sizes": hop_sizes,
103
- "win_lengths": win_lengths,
104
- "perceptual_weighting": True
105
- },
106
- "weights": {
107
- "mrstft": 1.0,
108
- }
109
- },
110
- "time": {
111
- "type": "l1",
112
- "config": {},
113
- "weights": {
114
- "l1": 0.0,
115
- }
116
- }
117
- }
118
-
119
- self.loss_config = loss_config
120
-
121
- # Spectral reconstruction loss
122
-
123
- stft_loss_args = loss_config['spectral']['config']
124
-
125
- if self.autoencoder.out_channels == 2:
126
- self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
127
- self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
128
- else:
129
- self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
130
-
131
- # Discriminator
132
-
133
- if loss_config['discriminator']['type'] == 'oobleck':
134
- self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config'])
135
- elif loss_config['discriminator']['type'] == 'encodec':
136
- self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config'])
137
- elif loss_config['discriminator']['type'] == 'dac':
138
- self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config'])
139
-
140
- self.gen_loss_modules = []
141
-
142
- # Adversarial and feature matching losses
143
- self.gen_loss_modules += [
144
- ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'),
145
- ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'),
146
- ]
147
-
148
- if self.teacher_model is not None:
149
- # Distillation losses
150
-
151
- stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25
152
- self.gen_loss_modules += [
153
- AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss
154
- AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder
155
- AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder
156
- AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder
157
- ]
158
-
159
- else:
160
-
161
- # Reconstruction loss
162
- self.gen_loss_modules += [
163
- AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
164
- ]
165
-
166
- if self.autoencoder.out_channels == 2:
167
-
168
- # Add left and right channel reconstruction losses in addition to the sum and difference
169
- self.gen_loss_modules += [
170
- AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2),
171
- AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2),
172
- ]
173
-
174
- self.gen_loss_modules += [
175
- AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
176
- ]
177
-
178
- if self.loss_config['time']['weights']['l1'] > 0.0:
179
- self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss'))
180
-
181
- if self.autoencoder.bottleneck is not None:
182
- self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)
183
-
184
- self.losses_gen = MultiLoss(self.gen_loss_modules)
185
-
186
- self.disc_loss_modules = [
187
- ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
188
- ]
189
-
190
- self.losses_disc = MultiLoss(self.disc_loss_modules)
191
-
192
- # Set up EMA for model weights
193
- self.autoencoder_ema = None
194
-
195
- self.use_ema = use_ema
196
-
197
- if self.use_ema:
198
- self.autoencoder_ema = EMA(
199
- self.autoencoder,
200
- ema_model=ema_copy,
201
- beta=0.9999,
202
- power=3/4,
203
- update_every=1,
204
- update_after_step=1
205
- )
206
-
207
- self.latent_mask_ratio = latent_mask_ratio
208
-
209
- def configure_optimizers(self):
210
-
211
- opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
212
- opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())
213
-
214
- if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
215
- sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
216
- sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
217
- return [opt_gen, opt_disc], [sched_gen, sched_disc]
218
-
219
- return [opt_gen, opt_disc]
220
-
221
- def training_step(self, batch, batch_idx):
222
- reals, _ = batch
223
-
224
- # Remove extra dimension added by WebDataset
225
- if reals.ndim == 4 and reals.shape[0] == 1:
226
- reals = reals[0]
227
-
228
- if self.global_step >= self.warmup_steps:
229
- self.warmed_up = True
230
-
231
- loss_info = {}
232
-
233
- loss_info["reals"] = reals
234
-
235
- encoder_input = reals
236
-
237
- if self.force_input_mono and encoder_input.shape[1] > 1:
238
- encoder_input = encoder_input.mean(dim=1, keepdim=True)
239
-
240
- loss_info["encoder_input"] = encoder_input
241
-
242
- data_std = encoder_input.std()
243
-
244
- if self.warmed_up and self.encoder_freeze_on_warmup:
245
- with torch.no_grad():
246
- latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
247
- else:
248
- latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
249
-
250
- loss_info["latents"] = latents
251
-
252
- loss_info.update(encoder_info)
253
-
254
- # Encode with teacher model for distillation
255
- if self.teacher_model is not None:
256
- with torch.no_grad():
257
- teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
258
- loss_info['teacher_latents'] = teacher_latents
259
-
260
- # Optionally mask out some latents for noise resistance
261
- if self.latent_mask_ratio > 0.0:
262
- mask = torch.rand_like(latents) < self.latent_mask_ratio
263
- latents = torch.where(mask, torch.zeros_like(latents), latents)
264
-
265
- decoded = self.autoencoder.decode(latents)
266
-
267
- loss_info["decoded"] = decoded
268
-
269
- if self.autoencoder.out_channels == 2:
270
- loss_info["decoded_left"] = decoded[:, 0:1, :]
271
- loss_info["decoded_right"] = decoded[:, 1:2, :]
272
- loss_info["reals_left"] = reals[:, 0:1, :]
273
- loss_info["reals_right"] = reals[:, 1:2, :]
274
-
275
- # Distillation
276
- if self.teacher_model is not None:
277
- with torch.no_grad():
278
- teacher_decoded = self.teacher_model.decode(teacher_latents)
279
- own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher
280
- teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model
281
-
282
- loss_info['teacher_decoded'] = teacher_decoded
283
- loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
284
- loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded
285
-
286
-
287
- if self.warmed_up:
288
- loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
289
- else:
290
- loss_dis = torch.tensor(0.).to(reals)
291
- loss_adv = torch.tensor(0.).to(reals)
292
- feature_matching_distance = torch.tensor(0.).to(reals)
293
-
294
- loss_info["loss_dis"] = loss_dis
295
- loss_info["loss_adv"] = loss_adv
296
- loss_info["feature_matching_distance"] = feature_matching_distance
297
-
298
- opt_gen, opt_disc = self.optimizers()
299
-
300
- lr_schedulers = self.lr_schedulers()
301
-
302
- sched_gen = None
303
- sched_disc = None
304
-
305
- if lr_schedulers is not None:
306
- sched_gen, sched_disc = lr_schedulers
307
-
308
- # Train the discriminator
309
- if self.global_step % 2 and self.warmed_up:
310
- loss, losses = self.losses_disc(loss_info)
311
-
312
- log_dict = {
313
- 'train/disc_lr': opt_disc.param_groups[0]['lr']
314
- }
315
-
316
- opt_disc.zero_grad()
317
- self.manual_backward(loss)
318
- opt_disc.step()
319
-
320
- if sched_disc is not None:
321
- # sched step every step
322
- sched_disc.step()
323
-
324
- # Train the generator
325
- else:
326
-
327
- loss, losses = self.losses_gen(loss_info)
328
-
329
- if self.use_ema:
330
- self.autoencoder_ema.update()
331
-
332
- opt_gen.zero_grad()
333
- self.manual_backward(loss)
334
- opt_gen.step()
335
-
336
- if sched_gen is not None:
337
- # scheduler step every step
338
- sched_gen.step()
339
-
340
- log_dict = {
341
- 'train/loss': loss.detach(),
342
- 'train/latent_std': latents.std().detach(),
343
- 'train/data_std': data_std.detach(),
344
- 'train/gen_lr': opt_gen.param_groups[0]['lr']
345
- }
346
-
347
- for loss_name, loss_value in losses.items():
348
- log_dict[f'train/{loss_name}'] = loss_value.detach()
349
-
350
- self.log_dict(log_dict, prog_bar=True, on_step=True)
351
-
352
- return loss
353
-
354
- def export_model(self, path, use_safetensors=False):
355
- if self.autoencoder_ema is not None:
356
- model = self.autoencoder_ema.ema_model
357
- else:
358
- model = self.autoencoder
359
-
360
- if use_safetensors:
361
- save_model(model, path)
362
- else:
363
- torch.save({"state_dict": model.state_dict()}, path)
364
-
365
-
366
- class AutoencoderDemoCallback(pl.Callback):
367
- def __init__(
368
- self,
369
- demo_dl,
370
- demo_every=2000,
371
- sample_size=65536,
372
- sample_rate=48000
373
- ):
374
- super().__init__()
375
- self.demo_every = demo_every
376
- self.demo_samples = sample_size
377
- self.demo_dl = iter(demo_dl)
378
- self.sample_rate = sample_rate
379
- self.last_demo_step = -1
380
-
381
- @rank_zero_only
382
- @torch.no_grad()
383
- def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
384
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
385
- return
386
-
387
- self.last_demo_step = trainer.global_step
388
-
389
- module.eval()
390
-
391
- try:
392
- demo_reals, _ = next(self.demo_dl)
393
-
394
- # Remove extra dimension added by WebDataset
395
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
396
- demo_reals = demo_reals[0]
397
-
398
- encoder_input = demo_reals
399
-
400
- encoder_input = encoder_input.to(module.device)
401
-
402
- if module.force_input_mono:
403
- encoder_input = encoder_input.mean(dim=1, keepdim=True)
404
-
405
- demo_reals = demo_reals.to(module.device)
406
-
407
- with torch.no_grad():
408
- if module.use_ema:
409
-
410
- latents = module.autoencoder_ema.ema_model.encode(encoder_input)
411
-
412
- fakes = module.autoencoder_ema.ema_model.decode(latents)
413
- else:
414
- latents = module.autoencoder.encode(encoder_input)
415
-
416
- fakes = module.autoencoder.decode(latents)
417
-
418
- #Interleave reals and fakes
419
- reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
420
-
421
- # Put the demos together
422
- reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
423
-
424
- log_dict = {}
425
-
426
- filename = f'recon_{trainer.global_step:08}.wav'
427
- reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
428
- torchaudio.save(filename, reals_fakes, self.sample_rate)
429
-
430
- log_dict[f'recon'] = wandb.Audio(filename,
431
- sample_rate=self.sample_rate,
432
- caption=f'Reconstructed')
433
-
434
- log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
435
- log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
436
-
437
- log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
438
-
439
- trainer.logger.experiment.log(log_dict)
440
- except Exception as e:
441
- print(f'{type(e).__name__}: {e}')
442
- raise e
443
- finally:
444
- module.train()
445
-
446
- def create_loss_modules_from_bottleneck(bottleneck, loss_config):
447
- losses = []
448
-
449
- if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
450
- try:
451
- kl_weight = loss_config['bottleneck']['weights']['kl']
452
- except:
453
- kl_weight = 1e-6
454
-
455
- kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss')
456
- losses.append(kl_loss)
457
-
458
- if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
459
- quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss')
460
- losses.append(quantizer_loss)
461
-
462
- if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck):
463
- codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss')
464
- commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss')
465
- losses.append(codebook_loss)
466
- losses.append(commitment_loss)
467
-
468
- if isinstance(bottleneck, WassersteinBottleneck):
469
- try:
470
- mmd_weight = loss_config['bottleneck']['weights']['mmd']
471
- except:
472
- mmd_weight = 100
473
-
474
- mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss')
475
- losses.append(mmd_loss)
476
-
477
- return losses
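
For reference, a sketch of the `loss_config` dictionary this training wrapper reads when you do not rely on its built-in defaults. The keys mirror the lookups in the code above; the weights shown are placeholders, not recommended values, and the `bottleneck` block is only consulted for the bottleneck types handled in `create_loss_modules_from_bottleneck`.

```python
# Illustrative loss_config for AutoencoderTrainingWrapper; values are placeholders.
scales = [2048, 1024, 512, 256, 128, 64, 32]
hops = [int(s * 0.25) for s in scales]  # 75% overlap, mirroring the defaults above

loss_config = {
    "discriminator": {
        "type": "encodec",  # one of "oobleck", "encodec", "dac"
        "config": {"n_ffts": scales, "hop_lengths": hops, "win_lengths": scales, "filters": 32},
        "weights": {"adversarial": 0.1, "feature_matching": 5.0},
    },
    "spectral": {
        "type": "mrstft",
        "config": {"fft_sizes": scales, "hop_sizes": hops, "win_lengths": scales,
                   "perceptual_weighting": True},
        "weights": {"mrstft": 1.0},
    },
    "time": {"type": "l1", "config": {}, "weights": {"l1": 0.0}},
    "bottleneck": {"weights": {"kl": 1e-6}},  # read only for VAE-style bottlenecks
}
```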
 
stable/build/lib/stable_audio_tools/training/diffusion.py DELETED
@@ -1,1505 +0,0 @@
1
- import pytorch_lightning as pl
2
- import sys, gc
3
- import random
4
- import torch
5
- import torchaudio
6
- import typing as tp
7
- import wandb
8
-
9
- from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
10
- import auraloss
11
- from ema_pytorch import EMA
12
- from einops import rearrange
13
- from safetensors.torch import save_file
14
- from torch import optim
15
- from torch.nn import functional as F
16
- from pytorch_lightning.utilities.rank_zero import rank_zero_only
17
-
18
- from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
19
- from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
20
- from ..models.autoencoders import DiffusionAutoencoder
21
- from ..models.diffusion_prior import PriorType
22
- from .autoencoders import create_loss_modules_from_bottleneck
23
- from .losses import AuralossLoss, MSELoss, MultiLoss
24
- from .utils import create_optimizer_from_config, create_scheduler_from_config
25
-
26
- from time import time
27
-
28
- class Profiler:
29
-
30
- def __init__(self):
31
- self.ticks = [[time(), None]]
32
-
33
- def tick(self, msg):
34
- self.ticks.append([time(), msg])
35
-
36
- def __repr__(self):
37
- rep = 80 * "=" + "\n"
38
- for i in range(1, len(self.ticks)):
39
- msg = self.ticks[i][1]
40
- ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
41
- rep += msg + f": {ellapsed*1000:.2f}ms\n"
42
- rep += 80 * "=" + "\n\n\n"
43
- return rep
44
-
45
- class DiffusionUncondTrainingWrapper(pl.LightningModule):
46
- '''
47
- Wrapper for training an unconditional audio diffusion model (like Dance Diffusion).
48
- '''
49
- def __init__(
50
- self,
51
- model: DiffusionModelWrapper,
52
- lr: float = 1e-4,
53
- pre_encoded: bool = False
54
- ):
55
- super().__init__()
56
-
57
- self.diffusion = model
58
-
59
- self.diffusion_ema = EMA(
60
- self.diffusion.model,
61
- beta=0.9999,
62
- power=3/4,
63
- update_every=1,
64
- update_after_step=1
65
- )
66
-
67
- self.lr = lr
68
-
69
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
70
-
71
- loss_modules = [
72
- MSELoss("v",
73
- "targets",
74
- weight=1.0,
75
- name="mse_loss"
76
- )
77
- ]
78
-
79
- self.losses = MultiLoss(loss_modules)
80
-
81
- self.pre_encoded = pre_encoded
82
-
83
- def configure_optimizers(self):
84
- return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
85
-
86
- def training_step(self, batch, batch_idx):
87
- reals = batch[0]
88
-
89
- if reals.ndim == 4 and reals.shape[0] == 1:
90
- reals = reals[0]
91
-
92
- diffusion_input = reals
93
-
94
- loss_info = {}
95
-
96
- if not self.pre_encoded:
97
- loss_info["audio_reals"] = diffusion_input
98
-
99
- if self.diffusion.pretransform is not None:
100
- if not self.pre_encoded:
101
- with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
102
- diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
103
- else:
104
- # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
105
- if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
106
- diffusion_input = diffusion_input / self.diffusion.pretransform.scale
107
-
108
- loss_info["reals"] = diffusion_input
109
-
110
- # Draw uniformly distributed continuous timesteps
111
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
112
-
113
- # Calculate the noise schedule parameters for those timesteps
114
- alphas, sigmas = get_alphas_sigmas(t)
115
-
116
- # Combine the ground truth data and the noise
117
- alphas = alphas[:, None, None]
118
- sigmas = sigmas[:, None, None]
119
- noise = torch.randn_like(diffusion_input)
120
- noised_inputs = diffusion_input * alphas + noise * sigmas
121
- targets = noise * alphas - diffusion_input * sigmas
122
-
123
- with torch.cuda.amp.autocast():
124
- v = self.diffusion(noised_inputs, t)
125
-
126
- loss_info.update({
127
- "v": v,
128
- "targets": targets
129
- })
130
-
131
- loss, losses = self.losses(loss_info)
132
-
133
- log_dict = {
134
- 'train/loss': loss.detach(),
135
- 'train/std_data': diffusion_input.std(),
136
- }
137
-
138
- for loss_name, loss_value in losses.items():
139
- log_dict[f"train/{loss_name}"] = loss_value.detach()
140
-
141
- self.log_dict(log_dict, prog_bar=True, on_step=True)
142
- return loss
143
-
144
- def on_before_zero_grad(self, *args, **kwargs):
145
- self.diffusion_ema.update()
146
-
147
- def export_model(self, path, use_safetensors=False):
148
-
149
- self.diffusion.model = self.diffusion_ema.ema_model
150
-
151
- if use_safetensors:
152
- save_file(self.diffusion.state_dict(), path)
153
- else:
154
- torch.save({"state_dict": self.diffusion.state_dict()}, path)
155
-
156
- class DiffusionUncondDemoCallback(pl.Callback):
157
- def __init__(self,
158
- demo_every=2000,
159
- num_demos=8,
160
- demo_steps=250,
161
- sample_rate=48000
162
- ):
163
- super().__init__()
164
-
165
- self.demo_every = demo_every
166
- self.num_demos = num_demos
167
- self.demo_steps = demo_steps
168
- self.sample_rate = sample_rate
169
- self.last_demo_step = -1
170
-
171
- @rank_zero_only
172
- @torch.no_grad()
173
- def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
174
-
175
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
176
- return
177
-
178
- self.last_demo_step = trainer.global_step
179
-
180
- demo_samples = module.diffusion.sample_size
181
-
182
- if module.diffusion.pretransform is not None:
183
- demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
184
-
185
- noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
186
-
187
- try:
188
- with torch.cuda.amp.autocast():
189
- fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
190
-
191
- if module.diffusion.pretransform is not None:
192
- fakes = module.diffusion.pretransform.decode(fakes)
193
-
194
- # Put the demos together
195
- fakes = rearrange(fakes, 'b d n -> d (b n)')
196
-
197
- log_dict = {}
198
-
199
- filename = f'demo_{trainer.global_step:08}.wav'
200
- fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
201
- torchaudio.save(filename, fakes, self.sample_rate)
202
-
203
- log_dict[f'demo'] = wandb.Audio(filename,
204
- sample_rate=self.sample_rate,
205
- caption=f'Reconstructed')
206
-
207
- log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
208
-
209
- trainer.logger.experiment.log(log_dict)
210
-
211
- del fakes
212
-
213
- except Exception as e:
214
- print(f'{type(e).__name__}: {e}')
215
- finally:
216
- gc.collect()
217
- torch.cuda.empty_cache()
218
-
219
- class DiffusionCondTrainingWrapper(pl.LightningModule):
220
- '''
221
- Wrapper for training a conditional audio diffusion model.
222
- '''
223
- def __init__(
224
- self,
225
- model: ConditionedDiffusionModelWrapper,
226
- lr: float = None,
227
- mask_padding: bool = False,
228
- mask_padding_dropout: float = 0.0,
229
- use_ema: bool = True,
230
- log_loss_info: bool = False,
231
- optimizer_configs: dict = None,
232
- pre_encoded: bool = False,
233
- cfg_dropout_prob = 0.1,
234
- timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
235
- ):
236
- super().__init__()
237
-
238
- self.diffusion = model
239
-
240
- if use_ema:
241
- self.diffusion_ema = EMA(
242
- self.diffusion.model,
243
- beta=0.9999,
244
- power=3/4,
245
- update_every=1,
246
- update_after_step=1,
247
- include_online_model=False
248
- )
249
- else:
250
- self.diffusion_ema = None
251
-
252
- self.mask_padding = mask_padding
253
- self.mask_padding_dropout = mask_padding_dropout
254
-
255
- self.cfg_dropout_prob = cfg_dropout_prob
256
-
257
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
258
-
259
- self.timestep_sampler = timestep_sampler
260
-
261
- self.diffusion_objective = model.diffusion_objective
262
-
263
- self.loss_modules = [
264
- MSELoss("output",
265
- "targets",
266
- weight=1.0,
267
- mask_key="padding_mask" if self.mask_padding else None,
268
- name="mse_loss"
269
- )
270
- ]
271
-
272
- self.losses = MultiLoss(self.loss_modules)
273
-
274
- self.log_loss_info = log_loss_info
275
-
276
- assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
277
-
278
- if optimizer_configs is None:
279
- optimizer_configs = {
280
- "diffusion": {
281
- "optimizer": {
282
- "type": "Adam",
283
- "config": {
284
- "lr": lr
285
- }
286
- }
287
- }
288
- }
289
- else:
290
- if lr is not None:
291
- print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
292
-
293
- self.optimizer_configs = optimizer_configs
294
-
295
- self.pre_encoded = pre_encoded
296
-
297
- def configure_optimizers(self):
298
- diffusion_opt_config = self.optimizer_configs['diffusion']
299
- opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
300
-
301
- if "scheduler" in diffusion_opt_config:
302
- sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
303
- sched_diff_config = {
304
- "scheduler": sched_diff,
305
- "interval": "step"
306
- }
307
- return [opt_diff], [sched_diff_config]
308
-
309
- return [opt_diff]
310
-
311
- def training_step(self, batch, batch_idx):
312
- reals, metadata = batch
313
-
314
- p = Profiler()
315
-
316
- if reals.ndim == 4 and reals.shape[0] == 1:
317
- reals = reals[0]
318
-
319
- loss_info = {}
320
-
321
- diffusion_input = reals
322
-
323
- if not self.pre_encoded:
324
- loss_info["audio_reals"] = diffusion_input
325
-
326
- p.tick("setup")
327
-
328
- with torch.cuda.amp.autocast():
329
- conditioning = self.diffusion.conditioner(metadata, self.device)
330
-
331
- # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding
332
- use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout
333
-
334
- # Create batch tensor of attention masks from the "mask" field of the metadata array
335
- if use_padding_mask:
336
- padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length)
337
-
338
- p.tick("conditioning")
339
-
340
- if self.diffusion.pretransform is not None:
341
- self.diffusion.pretransform.to(self.device)
342
-
343
- if not self.pre_encoded:
344
- with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
345
- diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
346
- p.tick("pretransform")
347
-
348
- # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
349
- if use_padding_mask:
350
- padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
351
- else:
352
- # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
353
- if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
354
- diffusion_input = diffusion_input / self.diffusion.pretransform.scale
355
-
356
- if self.timestep_sampler == "uniform":
357
- # Draw uniformly distributed continuous timesteps
358
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
359
- elif self.timestep_sampler == "logit_normal":
360
- t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
361
-
362
- # Calculate the noise schedule parameters for those timesteps
363
- if self.diffusion_objective == "v":
364
- alphas, sigmas = get_alphas_sigmas(t)
365
- elif self.diffusion_objective == "rectified_flow":
366
- alphas, sigmas = 1-t, t
367
-
368
- # Combine the ground truth data and the noise
369
- alphas = alphas[:, None, None]
370
- sigmas = sigmas[:, None, None]
371
- noise = torch.randn_like(diffusion_input)
372
- noised_inputs = diffusion_input * alphas + noise * sigmas
373
-
374
- if self.diffusion_objective == "v":
375
- targets = noise * alphas - diffusion_input * sigmas
376
- elif self.diffusion_objective == "rectified_flow":
377
- targets = noise - diffusion_input
378
-
379
- p.tick("noise")
380
-
381
- extra_args = {}
382
-
383
- if use_padding_mask:
384
- extra_args["mask"] = padding_masks
385
-
386
- with torch.cuda.amp.autocast():
387
- p.tick("amp")
388
- output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
389
- p.tick("diffusion")
390
-
391
- loss_info.update({
392
- "output": output,
393
- "targets": targets,
394
- "padding_mask": padding_masks if use_padding_mask else None,
395
- })
396
-
397
- loss, losses = self.losses(loss_info)
398
-
399
- p.tick("loss")
400
-
401
- if self.log_loss_info:
402
- # Loss debugging logs
403
- num_loss_buckets = 10
404
- bucket_size = 1 / num_loss_buckets
405
- loss_all = F.mse_loss(output, targets, reduction="none")
406
-
407
- sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
408
-
409
- # gather loss_all across all GPUs
410
- loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
411
-
412
- # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
413
- loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
414
-
415
- # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
416
- debug_log_dict = {
417
- f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
418
- }
419
-
420
- self.log_dict(debug_log_dict)
421
-
422
-
423
- log_dict = {
424
- 'train/loss': loss.detach(),
425
- 'train/std_data': diffusion_input.std(),
426
- 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
427
- }
428
-
429
- for loss_name, loss_value in losses.items():
430
- log_dict[f"train/{loss_name}"] = loss_value.detach()
431
-
432
- self.log_dict(log_dict, prog_bar=True, on_step=True)
433
- p.tick("log")
434
- #print(f"Profiler: {p}")
435
- return loss
436
-
437
- def on_before_zero_grad(self, *args, **kwargs):
438
- if self.diffusion_ema is not None:
439
- self.diffusion_ema.update()
440
-
441
- def export_model(self, path, use_safetensors=False):
442
- if self.diffusion_ema is not None:
443
- self.diffusion.model = self.diffusion_ema.ema_model
444
-
445
- if use_safetensors:
446
- save_file(self.diffusion.state_dict(), path)
447
- else:
448
- torch.save({"state_dict": self.diffusion.state_dict()}, path)
449
-
450
- class DiffusionCondDemoCallback(pl.Callback):
451
- def __init__(self,
452
- demo_every=2000,
453
- num_demos=8,
454
- sample_size=65536,
455
- demo_steps=250,
456
- sample_rate=48000,
457
- demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {},
458
- demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
459
- demo_cond_from_batch: bool = False,
460
- display_audio_cond: bool = False
461
- ):
462
- super().__init__()
463
-
464
- self.demo_every = demo_every
465
- self.num_demos = num_demos
466
- self.demo_samples = sample_size
467
- self.demo_steps = demo_steps
468
- self.sample_rate = sample_rate
469
- self.last_demo_step = -1
470
- self.demo_conditioning = demo_conditioning
471
- self.demo_cfg_scales = demo_cfg_scales
472
-
473
- # If true, the callback will use the metadata from the batch to generate the demo conditioning
474
- self.demo_cond_from_batch = demo_cond_from_batch
475
-
476
- # If true, the callback will display the audio conditioning
477
- self.display_audio_cond = display_audio_cond
478
-
479
- @rank_zero_only
480
- @torch.no_grad()
481
- def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
482
-
483
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
484
- return
485
-
486
- module.eval()
487
-
488
- print(f"Generating demo")
489
- self.last_demo_step = trainer.global_step
490
-
491
- demo_samples = self.demo_samples
492
-
493
- demo_cond = self.demo_conditioning
494
-
495
- if self.demo_cond_from_batch:
496
- # Get metadata from the batch
497
- demo_cond = batch[1][:self.num_demos]
498
-
499
- if module.diffusion.pretransform is not None:
500
- demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
501
-
502
- noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
503
-
504
- try:
505
- print("Getting conditioning")
506
- with torch.cuda.amp.autocast():
507
- conditioning = module.diffusion.conditioner(demo_cond, module.device)
508
-
509
- cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
510
-
511
- log_dict = {}
512
-
513
- if self.display_audio_cond:
514
- audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0)
515
- audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)')
516
-
517
- filename = f'demo_audio_cond_{trainer.global_step:08}.wav'
518
- audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu()
519
- torchaudio.save(filename, audio_inputs, self.sample_rate)
520
- log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning")
521
- log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs))
522
- trainer.logger.experiment.log(log_dict)
523
-
524
- for cfg_scale in self.demo_cfg_scales:
525
-
526
- print(f"Generating demo for cfg scale {cfg_scale}")
527
-
528
- with torch.cuda.amp.autocast():
529
- model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
530
-
531
- if module.diffusion_objective == "v":
532
- fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
533
- elif module.diffusion_objective == "rectified_flow":
534
- fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
535
-
536
- if module.diffusion.pretransform is not None:
537
- fakes = module.diffusion.pretransform.decode(fakes)
538
-
539
- # Put the demos together
540
- fakes = rearrange(fakes, 'b d n -> d (b n)')
541
-
542
- log_dict = {}
543
-
544
- filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
545
- fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
546
- torchaudio.save(filename, fakes, self.sample_rate)
547
-
548
- log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
549
- sample_rate=self.sample_rate,
550
- caption=f'Reconstructed')
551
-
552
- log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
553
-
554
- trainer.logger.experiment.log(log_dict)
555
-
556
- del fakes
557
-
558
- except Exception as e:
559
- raise e
560
- finally:
561
- gc.collect()
562
- torch.cuda.empty_cache()
563
- module.train()
564
-
565
- class DiffusionCondInpaintTrainingWrapper(pl.LightningModule):
566
- '''
567
- Wrapper for training a conditional audio diffusion model with inpainting.
568
- '''
569
- def __init__(
570
- self,
571
- model: ConditionedDiffusionModelWrapper,
572
- lr: float = 1e-4,
573
- max_mask_segments = 10,
574
- log_loss_info: bool = False,
575
- optimizer_configs: dict = None,
576
- use_ema: bool = True,
577
- pre_encoded: bool = False,
578
- cfg_dropout_prob = 0.1,
579
- timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
580
- ):
581
- super().__init__()
582
-
583
- self.diffusion = model
584
-
585
- self.use_ema = use_ema
586
-
587
- if self.use_ema:
588
- self.diffusion_ema = EMA(
589
- self.diffusion.model,
590
- beta=0.9999,
591
- power=3/4,
592
- update_every=1,
593
- update_after_step=1,
594
- include_online_model=False
595
- )
596
- else:
597
- self.diffusion_ema = None
598
-
599
- self.cfg_dropout_prob = cfg_dropout_prob
600
-
601
- self.lr = lr
602
- self.max_mask_segments = max_mask_segments
603
-
604
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
605
-
606
- self.timestep_sampler = timestep_sampler
607
-
608
- self.diffusion_objective = model.diffusion_objective
609
-
610
- self.loss_modules = [
611
- MSELoss("output",
612
- "targets",
613
- weight=1.0,
614
- name="mse_loss"
615
- )
616
- ]
617
-
618
- self.losses = MultiLoss(self.loss_modules)
619
-
620
- self.log_loss_info = log_loss_info
621
-
622
- assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
623
-
624
- if optimizer_configs is None:
625
- optimizer_configs = {
626
- "diffusion": {
627
- "optimizer": {
628
- "type": "Adam",
629
- "config": {
630
- "lr": lr
631
- }
632
- }
633
- }
634
- }
635
- else:
636
- if lr is not None:
637
- print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
638
-
639
- self.optimizer_configs = optimizer_configs
640
-
641
- self.pre_encoded = pre_encoded
642
-
643
- def configure_optimizers(self):
644
- diffusion_opt_config = self.optimizer_configs['diffusion']
645
- opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
646
-
647
- if "scheduler" in diffusion_opt_config:
648
- sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
649
- sched_diff_config = {
650
- "scheduler": sched_diff,
651
- "interval": "step"
652
- }
653
- return [opt_diff], [sched_diff_config]
654
-
655
- return [opt_diff]
656
-
657
- def random_mask(self, sequence, max_mask_length):
658
- b, _, sequence_length = sequence.size()
659
-
660
- # Create a mask tensor for each batch element
661
- masks = []
662
-
663
- for i in range(b):
664
- mask_type = random.randint(0, 2)
665
-
666
- if mask_type == 0: # Random mask with multiple segments
667
- num_segments = random.randint(1, self.max_mask_segments)
668
- max_segment_length = max_mask_length // num_segments
669
-
670
- segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
671
-
672
- mask = torch.ones((1, 1, sequence_length))
673
- for length in segment_lengths:
674
- mask_start = random.randint(0, sequence_length - length)
675
- mask[:, :, mask_start:mask_start + length] = 0
676
-
677
- elif mask_type == 1: # Full mask
678
- mask = torch.zeros((1, 1, sequence_length))
679
-
680
- elif mask_type == 2: # Causal mask
681
- mask = torch.ones((1, 1, sequence_length))
682
- mask_length = random.randint(1, max_mask_length)
683
- mask[:, :, -mask_length:] = 0
684
-
685
- mask = mask.to(sequence.device)
686
- masks.append(mask)
687
-
688
- # Concatenate the mask tensors into a single tensor
689
- mask = torch.cat(masks, dim=0).to(sequence.device)
690
-
691
- # Apply the mask to the sequence tensor for each batch element
692
- masked_sequence = sequence * mask
693
-
694
- return masked_sequence, mask
695
-
696
- def training_step(self, batch, batch_idx):
697
- reals, metadata = batch
698
-
699
- p = Profiler()
700
-
701
- if reals.ndim == 4 and reals.shape[0] == 1:
702
- reals = reals[0]
703
-
704
- loss_info = {}
705
-
706
- diffusion_input = reals
707
-
708
- if not self.pre_encoded:
709
- loss_info["audio_reals"] = diffusion_input
710
-
711
- p.tick("setup")
712
-
713
- with torch.cuda.amp.autocast():
714
- conditioning = self.diffusion.conditioner(metadata, self.device)
715
-
716
- p.tick("conditioning")
717
-
718
- if self.diffusion.pretransform is not None:
719
- self.diffusion.pretransform.to(self.device)
720
-
721
- if not self.pre_encoded:
722
- with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
723
- diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
724
- p.tick("pretransform")
725
-
726
- # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
727
- # if use_padding_mask:
728
- # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
729
- else:
730
- # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
731
- if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
732
- diffusion_input = diffusion_input / self.diffusion.pretransform.scale
733
-
734
- # Max mask size is the full sequence length
735
- max_mask_length = diffusion_input.shape[2]
736
-
737
- # Create a mask of random length for a random slice of the input
738
- masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
739
-
740
- conditioning['inpaint_mask'] = [mask]
741
- conditioning['inpaint_masked_input'] = [masked_input]
742
-
743
- if self.timestep_sampler == "uniform":
744
- # Draw uniformly distributed continuous timesteps
745
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
746
- elif self.timestep_sampler == "logit_normal":
747
- t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
748
-
749
- # Calculate the noise schedule parameters for those timesteps
750
- if self.diffusion_objective == "v":
751
- alphas, sigmas = get_alphas_sigmas(t)
752
- elif self.diffusion_objective == "rectified_flow":
753
- alphas, sigmas = 1-t, t
754
-
755
- # Combine the ground truth data and the noise
756
- alphas = alphas[:, None, None]
757
- sigmas = sigmas[:, None, None]
758
- noise = torch.randn_like(diffusion_input)
759
- noised_inputs = diffusion_input * alphas + noise * sigmas
760
-
761
- if self.diffusion_objective == "v":
762
- targets = noise * alphas - diffusion_input * sigmas
763
- elif self.diffusion_objective == "rectified_flow":
764
- targets = noise - diffusion_input
765
-
766
- p.tick("noise")
767
-
768
- extra_args = {}
769
-
770
- with torch.cuda.amp.autocast():
771
- p.tick("amp")
772
- output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
773
- p.tick("diffusion")
774
-
775
- loss_info.update({
776
- "output": output,
777
- "targets": targets,
778
- })
779
-
780
- loss, losses = self.losses(loss_info)
781
-
782
- if self.log_loss_info:
783
- # Loss debugging logs
784
- num_loss_buckets = 10
785
- bucket_size = 1 / num_loss_buckets
786
- loss_all = F.mse_loss(output, targets, reduction="none")
787
-
788
- sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
789
-
790
- # gather loss_all across all GPUs
791
- loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
792
-
793
- # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
794
- loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
795
-
796
- # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
797
- debug_log_dict = {
798
- f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
799
- }
800
-
801
- self.log_dict(debug_log_dict)
802
-
803
- log_dict = {
804
- 'train/loss': loss.detach(),
805
- 'train/std_data': diffusion_input.std(),
806
- 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
807
- }
808
-
809
- for loss_name, loss_value in losses.items():
810
- log_dict[f"train/{loss_name}"] = loss_value.detach()
811
-
812
- self.log_dict(log_dict, prog_bar=True, on_step=True)
813
- p.tick("log")
814
- #print(f"Profiler: {p}")
815
- return loss
816
-
817
- def on_before_zero_grad(self, *args, **kwargs):
818
- if self.diffusion_ema is not None:
819
- self.diffusion_ema.update()
820
-
821
- def export_model(self, path, use_safetensors=False):
822
- if self.diffusion_ema is not None:
823
- self.diffusion.model = self.diffusion_ema.ema_model
824
-
825
- if use_safetensors:
826
- save_file(self.diffusion.state_dict(), path)
827
- else:
828
- torch.save({"state_dict": self.diffusion.state_dict()}, path)
829
-
830
- class DiffusionCondInpaintDemoCallback(pl.Callback):
831
- def __init__(
832
- self,
833
- demo_dl,
834
- demo_every=2000,
835
- demo_steps=250,
836
- sample_size=65536,
837
- sample_rate=48000,
838
- demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7]
839
- ):
840
- super().__init__()
841
- self.demo_every = demo_every
842
- self.demo_steps = demo_steps
843
- self.demo_samples = sample_size
844
- self.demo_dl = iter(demo_dl)
845
- self.sample_rate = sample_rate
846
- self.demo_cfg_scales = demo_cfg_scales
847
- self.last_demo_step = -1
848
-
849
- @rank_zero_only
850
- @torch.no_grad()
851
- def on_train_batch_end(self, trainer, module: DiffusionCondInpaintTrainingWrapper, outputs, batch, batch_idx):
852
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
853
- return
854
-
855
- self.last_demo_step = trainer.global_step
856
-
857
- try:
858
- log_dict = {}
859
-
860
- demo_reals, metadata = next(self.demo_dl)
861
-
862
- # Remove extra dimension added by WebDataset
863
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
864
- demo_reals = demo_reals[0]
865
-
866
- demo_reals = demo_reals.to(module.device)
867
-
868
- if not module.pre_encoded:
869
- # Log the real audio
870
- log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
871
- # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
872
-
873
- if module.diffusion.pretransform is not None:
874
- module.diffusion.pretransform.to(module.device)
875
- with torch.cuda.amp.autocast():
876
- demo_reals = module.diffusion.pretransform.encode(demo_reals)
877
-
878
- demo_samples = demo_reals.shape[2]
879
-
880
- # Get conditioning
881
- conditioning = module.diffusion.conditioner(metadata, module.device)
882
-
883
- masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2])
884
-
885
- conditioning['inpaint_mask'] = [mask]
886
- conditioning['inpaint_masked_input'] = [masked_input]
887
-
888
- if module.diffusion.pretransform is not None:
889
- log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu()))
890
- else:
891
- log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
892
-
893
- cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
894
-
895
- noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
896
-
897
- trainer.logger.experiment.log(log_dict)
898
-
899
- for cfg_scale in self.demo_cfg_scales:
900
- model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
901
- print(f"Generating demo for cfg scale {cfg_scale}")
902
-
903
- if module.diffusion_objective == "v":
904
- fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
905
- elif module.diffusion_objective == "rectified_flow":
906
- fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
907
-
908
- if module.diffusion.pretransform is not None:
909
- with torch.cuda.amp.autocast():
910
- fakes = module.diffusion.pretransform.decode(fakes)
911
-
912
- # Put the demos together
913
- fakes = rearrange(fakes, 'b d n -> d (b n)')
914
-
915
- log_dict = {}
916
-
917
- filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
918
- fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
919
- torchaudio.save(filename, fakes, self.sample_rate)
920
-
921
- log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
922
- sample_rate=self.sample_rate,
923
- caption=f'Reconstructed')
924
-
925
- log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
926
-
927
- trainer.logger.experiment.log(log_dict)
928
- except Exception as e:
929
- print(f'{type(e).__name__}: {e}')
930
- raise e
931
-
932
- class DiffusionAutoencoderTrainingWrapper(pl.LightningModule):
933
- '''
934
- Wrapper for training a diffusion autoencoder
935
- '''
936
- def __init__(
937
- self,
938
- model: DiffusionAutoencoder,
939
- lr: float = 1e-4,
940
- ema_copy = None,
941
- use_reconstruction_loss: bool = False
942
- ):
943
- super().__init__()
944
-
945
- self.diffae = model
946
-
947
- self.diffae_ema = EMA(
948
- self.diffae,
949
- ema_model=ema_copy,
950
- beta=0.9999,
951
- power=3/4,
952
- update_every=1,
953
- update_after_step=1,
954
- include_online_model=False
955
- )
956
-
957
- self.lr = lr
958
-
959
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
960
-
961
- loss_modules = [
962
- MSELoss("v",
963
- "targets",
964
- weight=1.0,
965
- name="mse_loss"
966
- )
967
- ]
968
-
969
- if model.bottleneck is not None:
970
- # TODO: Use loss config for configurable bottleneck weights and reconstruction losses
971
- loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {})
972
-
973
- self.use_reconstruction_loss = use_reconstruction_loss
974
-
975
- if use_reconstruction_loss:
976
- scales = [2048, 1024, 512, 256, 128, 64, 32]
977
- hop_sizes = []
978
- win_lengths = []
979
- overlap = 0.75
980
- for s in scales:
981
- hop_sizes.append(int(s * (1 - overlap)))
982
- win_lengths.append(s)
983
-
984
- sample_rate = model.sample_rate
985
-
986
- stft_loss_args = {
987
- "fft_sizes": scales,
988
- "hop_sizes": hop_sizes,
989
- "win_lengths": win_lengths,
990
- "perceptual_weighting": True
991
- }
992
-
993
- out_channels = model.out_channels
994
-
995
- if model.pretransform is not None:
996
- out_channels = model.pretransform.io_channels
997
-
998
- if out_channels == 2:
999
- self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1000
- else:
1001
- self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1002
-
1003
- loss_modules.append(
1004
- AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1005
- )
1006
-
1007
- self.losses = MultiLoss(loss_modules)
1008
-
1009
- def configure_optimizers(self):
1010
- return optim.Adam([*self.diffae.parameters()], lr=self.lr)
1011
-
1012
- def training_step(self, batch, batch_idx):
1013
- reals = batch[0]
1014
-
1015
- if reals.ndim == 4 and reals.shape[0] == 1:
1016
- reals = reals[0]
1017
-
1018
- loss_info = {}
1019
-
1020
- loss_info["audio_reals"] = reals
1021
-
1022
- if self.diffae.pretransform is not None:
1023
- with torch.no_grad():
1024
- reals = self.diffae.pretransform.encode(reals)
1025
-
1026
- loss_info["reals"] = reals
1027
-
1028
- #Encode reals, skipping the pretransform since it was already applied
1029
- latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True)
1030
-
1031
- loss_info["latents"] = latents
1032
- loss_info.update(encoder_info)
1033
-
1034
- if self.diffae.decoder is not None:
1035
- latents = self.diffae.decoder(latents)
1036
-
1037
- # Upsample latents to match diffusion length
1038
- if latents.shape[2] != reals.shape[2]:
1039
- latents = F.interpolate(latents, size=reals.shape[2], mode='nearest')
1040
-
1041
- loss_info["latents_upsampled"] = latents
1042
-
1043
- # Draw uniformly distributed continuous timesteps
1044
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1045
-
1046
- # Calculate the noise schedule parameters for those timesteps
1047
- alphas, sigmas = get_alphas_sigmas(t)
1048
-
1049
- # Combine the ground truth data and the noise
1050
- alphas = alphas[:, None, None]
1051
- sigmas = sigmas[:, None, None]
1052
- noise = torch.randn_like(reals)
1053
- noised_reals = reals * alphas + noise * sigmas
1054
- targets = noise * alphas - reals * sigmas
1055
-
1056
- with torch.cuda.amp.autocast():
1057
- v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
1058
-
1059
- loss_info.update({
1060
- "v": v,
1061
- "targets": targets
1062
- })
1063
-
1064
- if self.use_reconstruction_loss:
1065
- pred = noised_reals * alphas - v * sigmas
1066
-
1067
- loss_info["pred"] = pred
1068
-
1069
- if self.diffae.pretransform is not None:
1070
- pred = self.diffae.pretransform.decode(pred)
1071
- loss_info["audio_pred"] = pred
1072
-
1073
- loss, losses = self.losses(loss_info)
1074
-
1075
- log_dict = {
1076
- 'train/loss': loss.detach(),
1077
- 'train/std_data': reals.std(),
1078
- 'train/latent_std': latents.std(),
1079
- }
1080
-
1081
- for loss_name, loss_value in losses.items():
1082
- log_dict[f"train/{loss_name}"] = loss_value.detach()
1083
-
1084
- self.log_dict(log_dict, prog_bar=True, on_step=True)
1085
- return loss
1086
-
1087
- def on_before_zero_grad(self, *args, **kwargs):
1088
- self.diffae_ema.update()
1089
-
1090
- def export_model(self, path, use_safetensors=False):
1091
-
1092
- model = self.diffae_ema.ema_model
1093
-
1094
- if use_safetensors:
1095
- save_file(model.state_dict(), path)
1096
- else:
1097
- torch.save({"state_dict": model.state_dict()}, path)
1098
-
1099
- class DiffusionAutoencoderDemoCallback(pl.Callback):
1100
- def __init__(
1101
- self,
1102
- demo_dl,
1103
- demo_every=2000,
1104
- demo_steps=250,
1105
- sample_size=65536,
1106
- sample_rate=48000
1107
- ):
1108
- super().__init__()
1109
- self.demo_every = demo_every
1110
- self.demo_steps = demo_steps
1111
- self.demo_samples = sample_size
1112
- self.demo_dl = iter(demo_dl)
1113
- self.sample_rate = sample_rate
1114
- self.last_demo_step = -1
1115
-
1116
- @rank_zero_only
1117
- @torch.no_grad()
1118
- def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
1119
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1120
- return
1121
-
1122
- self.last_demo_step = trainer.global_step
1123
-
1124
- demo_reals, _ = next(self.demo_dl)
1125
-
1126
- # Remove extra dimension added by WebDataset
1127
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1128
- demo_reals = demo_reals[0]
1129
-
1130
- encoder_input = demo_reals
1131
-
1132
- encoder_input = encoder_input.to(module.device)
1133
-
1134
- demo_reals = demo_reals.to(module.device)
1135
-
1136
- with torch.no_grad(), torch.cuda.amp.autocast():
1137
- latents = module.diffae_ema.ema_model.encode(encoder_input).float()
1138
- fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
1139
-
1140
- #Interleave reals and fakes
1141
- reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1142
-
1143
- # Put the demos together
1144
- reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1145
-
1146
- log_dict = {}
1147
-
1148
- filename = f'recon_{trainer.global_step:08}.wav'
1149
- reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1150
- torchaudio.save(filename, reals_fakes, self.sample_rate)
1151
-
1152
- log_dict[f'recon'] = wandb.Audio(filename,
1153
- sample_rate=self.sample_rate,
1154
- caption=f'Reconstructed')
1155
-
1156
- log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
1157
- log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
1158
-
1159
- log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1160
-
1161
- if module.diffae_ema.ema_model.pretransform is not None:
1162
- with torch.no_grad(), torch.cuda.amp.autocast():
1163
- initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
1164
- first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
1165
- first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
1166
- first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
1167
- first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
1168
- torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
1169
-
1170
- log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
1171
-
1172
- log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
1173
- sample_rate=self.sample_rate,
1174
- caption=f'First Stage Reconstructed')
1175
-
1176
- log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
1177
-
1178
-
1179
- trainer.logger.experiment.log(log_dict)
1180
-
1181
- def create_source_mixture(reals, num_sources=2):
1182
- # Create a fake mixture source by mixing elements from the training batch together with random offsets
1183
- source = torch.zeros_like(reals)
1184
- for i in range(reals.shape[0]):
1185
- sources_added = 0
1186
-
1187
- js = list(range(reals.shape[0]))
1188
- random.shuffle(js)
1189
- for j in js:
1190
- if i == j or (i != j and sources_added < num_sources):
1191
- # Randomly offset the mixed element between 0 and the length of the source
1192
- seq_len = reals.shape[2]
1193
- offset = random.randint(0, seq_len-1)
1194
- source[i, :, offset:] += reals[j, :, :seq_len - offset]
1195
- if i == j:
1196
- # If this is the real one, shift the reals as well to ensure alignment
1197
- new_reals = torch.zeros_like(reals[i])
1198
- new_reals[:, offset:] = reals[i, :, :seq_len - offset]
1199
- reals[i] = new_reals
1200
- sources_added += 1
1201
-
1202
- return source
1203
-
1204
- class DiffusionPriorTrainingWrapper(pl.LightningModule):
1205
- '''
1206
- Wrapper for training a diffusion prior for inverse problems
1207
- Prior types:
1208
- mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
1209
- '''
1210
- def __init__(
1211
- self,
1212
- model: ConditionedDiffusionModelWrapper,
1213
- lr: float = 1e-4,
1214
- ema_copy = None,
1215
- prior_type: PriorType = PriorType.MonoToStereo,
1216
- use_reconstruction_loss: bool = False,
1217
- log_loss_info: bool = False,
1218
- ):
1219
- super().__init__()
1220
-
1221
- self.diffusion = model
1222
-
1223
- self.diffusion_ema = EMA(
1224
- self.diffusion,
1225
- ema_model=ema_copy,
1226
- beta=0.9999,
1227
- power=3/4,
1228
- update_every=1,
1229
- update_after_step=1,
1230
- include_online_model=False
1231
- )
1232
-
1233
- self.lr = lr
1234
-
1235
- self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
1236
-
1237
- self.log_loss_info = log_loss_info
1238
-
1239
- loss_modules = [
1240
- MSELoss("v",
1241
- "targets",
1242
- weight=1.0,
1243
- name="mse_loss"
1244
- )
1245
- ]
1246
-
1247
- self.use_reconstruction_loss = use_reconstruction_loss
1248
-
1249
- if use_reconstruction_loss:
1250
- scales = [2048, 1024, 512, 256, 128, 64, 32]
1251
- hop_sizes = []
1252
- win_lengths = []
1253
- overlap = 0.75
1254
- for s in scales:
1255
- hop_sizes.append(int(s * (1 - overlap)))
1256
- win_lengths.append(s)
1257
-
1258
- sample_rate = model.sample_rate
1259
-
1260
- stft_loss_args = {
1261
- "fft_sizes": scales,
1262
- "hop_sizes": hop_sizes,
1263
- "win_lengths": win_lengths,
1264
- "perceptual_weighting": True
1265
- }
1266
-
1267
- out_channels = model.io_channels
1268
-
1269
- self.audio_out_channels = out_channels
1270
-
1271
- if model.pretransform is not None:
1272
- out_channels = model.pretransform.io_channels
1273
-
1274
- if self.audio_out_channels == 2:
1275
- self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1276
- self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1277
-
1278
- # Add left and right channel reconstruction losses in addition to the sum and difference
1279
- loss_modules += [
1280
- AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
1281
- AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
1282
- ]
1283
-
1284
- else:
1285
- self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
1286
-
1287
- loss_modules.append(
1288
- AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
1289
- )
1290
-
1291
- self.losses = MultiLoss(loss_modules)
1292
-
1293
- self.prior_type = prior_type
1294
-
1295
- def configure_optimizers(self):
1296
- return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
1297
-
1298
- def training_step(self, batch, batch_idx):
1299
- reals, metadata = batch
1300
-
1301
- if reals.ndim == 4 and reals.shape[0] == 1:
1302
- reals = reals[0]
1303
-
1304
- loss_info = {}
1305
-
1306
- loss_info["audio_reals"] = reals
1307
-
1308
- if self.prior_type == PriorType.MonoToStereo:
1309
- source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
1310
- loss_info["audio_reals_mono"] = source
1311
- else:
1312
- raise ValueError(f"Unknown prior type {self.prior_type}")
1313
-
1314
- if self.diffusion.pretransform is not None:
1315
- with torch.no_grad():
1316
- reals = self.diffusion.pretransform.encode(reals)
1317
-
1318
- if self.prior_type in [PriorType.MonoToStereo]:
1319
- source = self.diffusion.pretransform.encode(source)
1320
-
1321
- if self.diffusion.conditioner is not None:
1322
- with torch.cuda.amp.autocast():
1323
- conditioning = self.diffusion.conditioner(metadata, self.device)
1324
- else:
1325
- conditioning = {}
1326
-
1327
- loss_info["reals"] = reals
1328
-
1329
- # Draw uniformly distributed continuous timesteps
1330
- t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
1331
-
1332
- # Calculate the noise schedule parameters for those timesteps
1333
- alphas, sigmas = get_alphas_sigmas(t)
1334
-
1335
- # Combine the ground truth data and the noise
1336
- alphas = alphas[:, None, None]
1337
- sigmas = sigmas[:, None, None]
1338
- noise = torch.randn_like(reals)
1339
- noised_reals = reals * alphas + noise * sigmas
1340
- targets = noise * alphas - reals * sigmas
1341
-
1342
- with torch.cuda.amp.autocast():
1343
-
1344
- conditioning['source'] = [source]
1345
-
1346
- v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
1347
-
1348
- loss_info.update({
1349
- "v": v,
1350
- "targets": targets
1351
- })
1352
-
1353
- if self.use_reconstruction_loss:
1354
- pred = noised_reals * alphas - v * sigmas
1355
-
1356
- loss_info["pred"] = pred
1357
-
1358
- if self.diffusion.pretransform is not None:
1359
- pred = self.diffusion.pretransform.decode(pred)
1360
- loss_info["audio_pred"] = pred
1361
-
1362
- if self.audio_out_channels == 2:
1363
- loss_info["pred_left"] = pred[:, 0:1, :]
1364
- loss_info["pred_right"] = pred[:, 1:2, :]
1365
- loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
1366
- loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
1367
-
1368
- loss, losses = self.losses(loss_info)
1369
-
1370
- if self.log_loss_info:
1371
- # Loss debugging logs
1372
- num_loss_buckets = 10
1373
- bucket_size = 1 / num_loss_buckets
1374
- loss_all = F.mse_loss(v, targets, reduction="none")
1375
-
1376
- sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
1377
-
1378
- # gather loss_all across all GPUs
1379
- loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
1380
-
1381
- # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
1382
- loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
1383
-
1384
- # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
1385
- debug_log_dict = {
1386
- f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
1387
- }
1388
-
1389
- self.log_dict(debug_log_dict)
1390
-
1391
- log_dict = {
1392
- 'train/loss': loss.detach(),
1393
- 'train/std_data': reals.std()
1394
- }
1395
-
1396
- for loss_name, loss_value in losses.items():
1397
- log_dict[f"train/{loss_name}"] = loss_value.detach()
1398
-
1399
- self.log_dict(log_dict, prog_bar=True, on_step=True)
1400
- return loss
1401
-
1402
- def on_before_zero_grad(self, *args, **kwargs):
1403
- self.diffusion_ema.update()
1404
-
1405
- def export_model(self, path, use_safetensors=False):
1406
-
1407
- #model = self.diffusion_ema.ema_model
1408
- model = self.diffusion
1409
-
1410
- if use_safetensors:
1411
- save_file(model.state_dict(), path)
1412
- else:
1413
- torch.save({"state_dict": model.state_dict()}, path)
1414
-
1415
- class DiffusionPriorDemoCallback(pl.Callback):
1416
- def __init__(
1417
- self,
1418
- demo_dl,
1419
- demo_every=2000,
1420
- demo_steps=250,
1421
- sample_size=65536,
1422
- sample_rate=48000
1423
- ):
1424
- super().__init__()
1425
- self.demo_every = demo_every
1426
- self.demo_steps = demo_steps
1427
- self.demo_samples = sample_size
1428
- self.demo_dl = iter(demo_dl)
1429
- self.sample_rate = sample_rate
1430
- self.last_demo_step = -1
1431
-
1432
- @rank_zero_only
1433
- @torch.no_grad()
1434
- def on_train_batch_end(self, trainer, module: DiffusionPriorTrainingWrapper, outputs, batch, batch_idx):
1435
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
1436
- return
1437
-
1438
- self.last_demo_step = trainer.global_step
1439
-
1440
- demo_reals, metadata = next(self.demo_dl)
1441
-
1442
- # Remove extra dimension added by WebDataset
1443
- if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
1444
- demo_reals = demo_reals[0]
1445
-
1446
- demo_reals = demo_reals.to(module.device)
1447
-
1448
- encoder_input = demo_reals
1449
-
1450
- if module.diffusion.conditioner is not None:
1451
- with torch.cuda.amp.autocast():
1452
- conditioning_tensors = module.diffusion.conditioner(metadata, module.device)
1453
-
1454
- else:
1455
- conditioning_tensors = {}
1456
-
1457
-
1458
- with torch.no_grad(), torch.cuda.amp.autocast():
1459
- if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1:
1460
- source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)
1461
-
1462
- if module.diffusion.pretransform is not None:
1463
- encoder_input = module.diffusion.pretransform.encode(encoder_input)
1464
- source_input = module.diffusion.pretransform.encode(source)
1465
- else:
1466
- source_input = source
1467
-
1468
- conditioning_tensors['source'] = [source_input]
1469
-
1470
- fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors)
1471
-
1472
- if module.diffusion.pretransform is not None:
1473
- fakes = module.diffusion.pretransform.decode(fakes)
1474
-
1475
- #Interleave reals and fakes
1476
- reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
1477
-
1478
- # Put the demos together
1479
- reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
1480
-
1481
- log_dict = {}
1482
-
1483
- filename = f'recon_{trainer.global_step:08}.wav'
1484
- reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
1485
- torchaudio.save(filename, reals_fakes, self.sample_rate)
1486
-
1487
- log_dict[f'recon'] = wandb.Audio(filename,
1488
- sample_rate=self.sample_rate,
1489
- caption=f'Reconstructed')
1490
-
1491
- log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
1492
-
1493
- #Log the source
1494
- filename = f'source_{trainer.global_step:08}.wav'
1495
- source = rearrange(source, 'b d n -> d (b n)')
1496
- source = source.to(torch.float32).mul(32767).to(torch.int16).cpu()
1497
- torchaudio.save(filename, source, self.sample_rate)
1498
-
1499
- log_dict[f'source'] = wandb.Audio(filename,
1500
- sample_rate=self.sample_rate,
1501
- caption=f'Source')
1502
-
1503
- log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source))
1504
-
1505
- trainer.logger.experiment.log(log_dict)
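
The training wrappers in this deleted module all construct noised inputs and targets from the same parameterization; the conditional wrappers additionally support a rectified-flow objective. The standalone sketch below is illustrative only: it assumes the cosine schedule commonly used for get_alphas_sigmas in this codebase, and make_training_pair is a hypothetical helper, not a function from the repository.

import math
import torch

def get_alphas_sigmas(t):
    # Assumed cosine schedule; mirrors the helper these wrappers import.
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def make_training_pair(x, t, objective="v"):
    # x: clean latents or audio, shape (batch, channels, length); t: timesteps in [0, 1], shape (batch,)
    if objective == "v":
        alphas, sigmas = get_alphas_sigmas(t)
    elif objective == "rectified_flow":
        alphas, sigmas = 1 - t, t
    else:
        raise ValueError(f"unknown objective: {objective}")
    alphas, sigmas = alphas[:, None, None], sigmas[:, None, None]
    noise = torch.randn_like(x)
    noised = x * alphas + noise * sigmas
    # "v" predicts a rotation of (noise, data); rectified flow predicts the straight-line velocity noise - data.
    target = noise * alphas - x * sigmas if objective == "v" else noise - x
    return noised, target
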
stable/build/lib/stable_audio_tools/training/factory.py DELETED
@@ -1,240 +0,0 @@
1
- import torch
2
- from torch.nn import Parameter
3
- from ..models.factory import create_model_from_config
4
-
5
- def create_training_wrapper_from_config(model_config, model):
6
- model_type = model_config.get('model_type', None)
7
- assert model_type is not None, 'model_type must be specified in model config'
8
-
9
- training_config = model_config.get('training', None)
10
- assert training_config is not None, 'training config must be specified in model config'
11
-
12
- if model_type == 'autoencoder':
13
- from .autoencoders import AutoencoderTrainingWrapper
14
-
15
- ema_copy = None
16
-
17
- if training_config.get("use_ema", False):
18
- ema_copy = create_model_from_config(model_config)
19
- ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once
20
- # Copy each weight to the ema copy
21
- for name, param in model.state_dict().items():
22
- if isinstance(param, Parameter):
23
- # backwards compatibility for serialized parameters
24
- param = param.data
25
- ema_copy.state_dict()[name].copy_(param)
26
-
27
- use_ema = training_config.get("use_ema", False)
28
-
29
- latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)
30
-
31
- teacher_model = training_config.get("teacher_model", None)
32
- if teacher_model is not None:
33
- teacher_model = create_model_from_config(teacher_model)
34
- teacher_model = teacher_model.eval().requires_grad_(False)
35
-
36
- teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
37
- if teacher_model_ckpt is not None:
38
- teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
39
- else:
40
- raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")
41
-
42
- return AutoencoderTrainingWrapper(
43
- model,
44
- lr=training_config["learning_rate"],
45
- warmup_steps=training_config.get("warmup_steps", 0),
46
- encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
47
- sample_rate=model_config["sample_rate"],
48
- loss_config=training_config.get("loss_configs", None),
49
- optimizer_configs=training_config.get("optimizer_configs", None),
50
- use_ema=use_ema,
51
- ema_copy=ema_copy if use_ema else None,
52
- force_input_mono=training_config.get("force_input_mono", False),
53
- latent_mask_ratio=latent_mask_ratio,
54
- teacher_model=teacher_model
55
- )
56
- elif model_type == 'diffusion_uncond':
57
- from .diffusion import DiffusionUncondTrainingWrapper
58
- return DiffusionUncondTrainingWrapper(
59
- model,
60
- lr=training_config["learning_rate"],
61
- pre_encoded=training_config.get("pre_encoded", False),
62
- )
63
- elif model_type == 'diffusion_cond':
64
- from .diffusion import DiffusionCondTrainingWrapper
65
- return DiffusionCondTrainingWrapper(
66
- model,
67
- lr=training_config.get("learning_rate", None),
68
- mask_padding=training_config.get("mask_padding", False),
69
- mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
70
- use_ema = training_config.get("use_ema", True),
71
- log_loss_info=training_config.get("log_loss_info", False),
72
- optimizer_configs=training_config.get("optimizer_configs", None),
73
- pre_encoded=training_config.get("pre_encoded", False),
74
- cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
75
- timestep_sampler = training_config.get("timestep_sampler", "uniform")
76
- )
77
- elif model_type == 'diffusion_prior':
78
- from .diffusion import DiffusionPriorTrainingWrapper
79
- from ..models.diffusion_prior import PriorType
80
-
81
- ema_copy = create_model_from_config(model_config)
82
-
83
- # Copy each weight to the ema copy
84
- for name, param in model.state_dict().items():
85
- if isinstance(param, Parameter):
86
- # backwards compatibility for serialized parameters
87
- param = param.data
88
- ema_copy.state_dict()[name].copy_(param)
89
-
90
- prior_type = training_config.get("prior_type", "mono_stereo")
91
-
92
- if prior_type == "mono_stereo":
93
- prior_type_enum = PriorType.MonoToStereo
94
- else:
95
- raise ValueError(f"Unknown prior type: {prior_type}")
96
-
97
- return DiffusionPriorTrainingWrapper(
98
- model,
99
- lr=training_config["learning_rate"],
100
- ema_copy=ema_copy,
101
- prior_type=prior_type_enum,
102
- log_loss_info=training_config.get("log_loss_info", False),
103
- use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
104
- )
105
- elif model_type == 'diffusion_cond_inpaint':
106
- from .diffusion import DiffusionCondInpaintTrainingWrapper
107
- return DiffusionCondInpaintTrainingWrapper(
108
- model,
109
- lr=training_config.get("learning_rate", None),
110
- max_mask_segments = training_config.get("max_mask_segments", 10),
111
- log_loss_info=training_config.get("log_loss_info", False),
112
- optimizer_configs=training_config.get("optimizer_configs", None),
113
- use_ema=training_config.get("use_ema", True),
114
- pre_encoded=training_config.get("pre_encoded", False),
115
- cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
116
- timestep_sampler = training_config.get("timestep_sampler", "uniform")
117
- )
118
- elif model_type == 'diffusion_autoencoder':
119
- from .diffusion import DiffusionAutoencoderTrainingWrapper
120
-
121
- ema_copy = create_model_from_config(model_config)
122
-
123
- # Copy each weight to the ema copy
124
- for name, param in model.state_dict().items():
125
- if isinstance(param, Parameter):
126
- # backwards compatibility for serialized parameters
127
- param = param.data
128
- ema_copy.state_dict()[name].copy_(param)
129
-
130
- return DiffusionAutoencoderTrainingWrapper(
131
- model,
132
- ema_copy=ema_copy,
133
- lr=training_config["learning_rate"],
134
- use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
135
- )
136
- elif model_type == 'lm':
137
- from .lm import AudioLanguageModelTrainingWrapper
138
-
139
- ema_copy = create_model_from_config(model_config)
140
-
141
- for name, param in model.state_dict().items():
142
- if isinstance(param, Parameter):
143
- # backwards compatibility for serialized parameters
144
- param = param.data
145
- ema_copy.state_dict()[name].copy_(param)
146
-
147
- return AudioLanguageModelTrainingWrapper(
148
- model,
149
- ema_copy=ema_copy,
150
- lr=training_config.get("learning_rate", None),
151
- use_ema=training_config.get("use_ema", False),
152
- optimizer_configs=training_config.get("optimizer_configs", None),
153
- pre_encoded=training_config.get("pre_encoded", False),
154
- )
155
-
156
- else:
157
- raise NotImplementedError(f'Unknown model type: {model_type}')
158
-
159
- def create_demo_callback_from_config(model_config, **kwargs):
160
- model_type = model_config.get('model_type', None)
161
- assert model_type is not None, 'model_type must be specified in model config'
162
-
163
- training_config = model_config.get('training', None)
164
- assert training_config is not None, 'training config must be specified in model config'
165
-
166
- demo_config = training_config.get("demo", {})
167
-
168
- if model_type == 'autoencoder':
169
- from .autoencoders import AutoencoderDemoCallback
170
- return AutoencoderDemoCallback(
171
- demo_every=demo_config.get("demo_every", 2000),
172
- sample_size=model_config["sample_size"],
173
- sample_rate=model_config["sample_rate"],
174
- **kwargs
175
- )
176
- elif model_type == 'diffusion_uncond':
177
- from .diffusion import DiffusionUncondDemoCallback
178
- return DiffusionUncondDemoCallback(
179
- demo_every=demo_config.get("demo_every", 2000),
180
- demo_steps=demo_config.get("demo_steps", 250),
181
- sample_rate=model_config["sample_rate"]
182
- )
183
- elif model_type == "diffusion_autoencoder":
184
- from .diffusion import DiffusionAutoencoderDemoCallback
185
- return DiffusionAutoencoderDemoCallback(
186
- demo_every=demo_config.get("demo_every", 2000),
187
- demo_steps=demo_config.get("demo_steps", 250),
188
- sample_size=model_config["sample_size"],
189
- sample_rate=model_config["sample_rate"],
190
- **kwargs
191
- )
192
- elif model_type == "diffusion_prior":
193
- from .diffusion import DiffusionPriorDemoCallback
194
- return DiffusionPriorDemoCallback(
195
- demo_every=demo_config.get("demo_every", 2000),
196
- demo_steps=demo_config.get("demo_steps", 250),
197
- sample_size=model_config["sample_size"],
198
- sample_rate=model_config["sample_rate"],
199
- **kwargs
200
- )
201
- elif model_type == "diffusion_cond":
202
- from .diffusion import DiffusionCondDemoCallback
203
-
204
- return DiffusionCondDemoCallback(
205
- demo_every=demo_config.get("demo_every", 2000),
206
- sample_size=model_config["sample_size"],
207
- sample_rate=model_config["sample_rate"],
208
- demo_steps=demo_config.get("demo_steps", 250),
209
- num_demos=demo_config["num_demos"],
210
- demo_cfg_scales=demo_config["demo_cfg_scales"],
211
- demo_conditioning=demo_config.get("demo_cond", {}),
212
- demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
213
- display_audio_cond=demo_config.get("display_audio_cond", False),
214
- )
215
- elif model_type == "diffusion_cond_inpaint":
216
- from .diffusion import DiffusionCondInpaintDemoCallback
217
-
218
- return DiffusionCondInpaintDemoCallback(
219
- demo_every=demo_config.get("demo_every", 2000),
220
- sample_size=model_config["sample_size"],
221
- sample_rate=model_config["sample_rate"],
222
- demo_steps=demo_config.get("demo_steps", 250),
223
- demo_cfg_scales=demo_config["demo_cfg_scales"],
224
- **kwargs
225
- )
226
-
227
- elif model_type == "lm":
228
- from .lm import AudioLanguageModelDemoCallback
229
-
230
- return AudioLanguageModelDemoCallback(
231
- demo_every=demo_config.get("demo_every", 2000),
232
- sample_size=model_config["sample_size"],
233
- sample_rate=model_config["sample_rate"],
234
- demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
235
- demo_conditioning=demo_config.get("demo_cond", None),
236
- num_demos=demo_config.get("num_demos", 8),
237
- **kwargs
238
- )
239
- else:
240
- raise NotImplementedError(f'Unknown model type: {model_type}')
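For orientation, here is a minimal sketch of how the factory helpers in this file are typically driven from a model config. The import paths assume the stable_audio_tools package from this build tree is importable, and the config file name points at the adapter config removed later in this commit; treat it as an illustration rather than a guaranteed entry point.

```python
import json

# Assumed import paths; they mirror the package layout in this build tree.
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.training.factory import (
    create_training_wrapper_from_config,
    create_demo_callback_from_config,
)

with open("stable/config_adapter.json") as f:
    model_config = json.load(f)

# Build the raw model, then wrap it for training and attach the demo callback.
model = create_model_from_config(model_config)
training_wrapper = create_training_wrapper_from_config(model_config, model)
demo_callback = create_demo_callback_from_config(model_config)
```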
 
stable/build/lib/stable_audio_tools/training/lm.py DELETED
@@ -1,267 +0,0 @@
1
- import pytorch_lightning as pl
2
- import sys, gc
3
- import random
4
- import torch
5
- import torchaudio
6
- import typing as tp
7
- import wandb
8
-
9
- from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
10
- from ema_pytorch import EMA
11
- from einops import rearrange
12
- from safetensors.torch import save_file
13
- from torch import optim
14
- from torch.nn import functional as F
15
- from pytorch_lightning.utilities.rank_zero import rank_zero_only
16
-
17
- from ..models.lm import AudioLanguageModelWrapper
18
- from .utils import create_optimizer_from_config, create_scheduler_from_config
19
-
20
- class AudioLanguageModelTrainingWrapper(pl.LightningModule):
21
- def __init__(
22
- self,
23
- model: AudioLanguageModelWrapper,
24
- lr = 1e-4,
25
- use_ema=False,
26
- ema_copy=None,
27
- optimizer_configs: dict = None,
28
- pre_encoded=False
29
- ):
30
- super().__init__()
31
-
32
- self.model = model
33
-
34
- self.model.pretransform.requires_grad_(False)
35
-
36
- self.model_ema = None
37
- if use_ema:
38
- self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)
39
-
40
- assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
41
-
42
- if optimizer_configs is None:
43
- optimizer_configs = {
44
- "lm": {
45
- "optimizer": {
46
- "type": "AdamW",
47
- "config": {
48
- "lr": lr,
49
- "betas": (0.9, 0.95),
50
- "weight_decay": 0.1
51
- }
52
- }
53
- }
54
- }
55
- else:
56
- if lr is not None:
57
- print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
58
-
59
- self.optimizer_configs = optimizer_configs
60
-
61
- self.pre_encoded = pre_encoded
62
-
63
- def configure_optimizers(self):
64
- lm_opt_config = self.optimizer_configs['lm']
65
- opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())
66
-
67
- if "scheduler" in lm_opt_config:
68
- sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
69
- sched_lm_config = {
70
- "scheduler": sched_lm,
71
- "interval": "step"
72
- }
73
- return [opt_lm], [sched_lm_config]
74
-
75
- return [opt_lm]
76
-
77
- # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license
78
- # License can be found in LICENSES/LICENSE_META.txt
79
-
80
- def _compute_cross_entropy(
81
- self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
82
- ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
83
- """Compute cross entropy between multi-codebook targets and model's logits.
84
- The cross entropy is computed per codebook to provide codebook-level cross entropy.
85
- Valid timesteps for each of the codebook are pulled from the mask, where invalid
86
- timesteps are set to 0.
87
-
88
- Args:
89
- logits (torch.Tensor): Model's logits of shape [B, K, T, card].
90
- targets (torch.Tensor): Target codes, of shape [B, K, T].
91
- mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
92
- Returns:
93
- ce (torch.Tensor): Cross entropy averaged over the codebooks
94
- ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
95
- """
96
- B, K, T = targets.shape
97
- assert logits.shape[:-1] == targets.shape
98
- assert mask.shape == targets.shape
99
- ce = torch.zeros([], device=targets.device)
100
- ce_per_codebook: tp.List[torch.Tensor] = []
101
- for k in range(K):
102
- logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
103
- targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
104
- mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
105
- ce_targets = targets_k[mask_k]
106
- ce_logits = logits_k[mask_k]
107
- q_ce = F.cross_entropy(ce_logits, ce_targets)
108
- ce += q_ce
109
- ce_per_codebook.append(q_ce.detach())
110
- # average cross entropy across codebooks
111
- ce = ce / K
112
- return ce, ce_per_codebook
113
-
114
- def training_step(self, batch, batch_idx):
115
- reals, metadata = batch
116
-
117
- if reals.ndim == 4 and reals.shape[0] == 1:
118
- reals = reals[0]
119
-
120
- if not self.pre_encoded:
121
- codes = self.model.pretransform.tokenize(reals)
122
- else:
123
- codes = reals
124
-
125
- padding_masks = []
126
- for md in metadata:
127
- if md["padding_mask"].ndim == 1:
128
- padding_masks.append(md["padding_mask"])
129
- else:
130
- padding_masks.append(md["padding_mask"][0])
131
-
132
- padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length)
133
-
134
- # Interpolate padding masks to the same length as the codes
135
- padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool()
136
-
137
- condition_tensors = None
138
-
139
- # If the model is conditioned, get the conditioning tensors
140
- if self.model.conditioner is not None:
141
- condition_tensors = self.model.conditioner(metadata, self.device)
142
-
143
- lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)
144
-
145
- logits = lm_output.logits # [b, k, t, c]
146
- logits_mask = lm_output.mask # [b, k, t]
147
-
148
- logits_mask = logits_mask & padding_masks
149
-
150
- cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask)
151
-
152
- loss = cross_entropy
153
-
154
- log_dict = {
155
- 'train/loss': loss.detach(),
156
- 'train/cross_entropy': cross_entropy.detach(),
157
- 'train/perplexity': torch.exp(cross_entropy).detach(),
158
- 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
159
- }
160
-
161
- for k, ce_q in enumerate(cross_entropy_per_codebook):
162
- log_dict[f'cross_entropy_q{k + 1}'] = ce_q
163
- log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q)
164
-
165
- self.log_dict(log_dict, prog_bar=True, on_step=True)
166
- return loss
167
-
168
- def on_before_zero_grad(self, *args, **kwargs):
169
- if self.model_ema is not None:
170
- self.model_ema.update()
171
-
172
- def export_model(self, path, use_safetensors=False):
173
-
174
- model = self.model_ema.ema_model if self.model_ema is not None else self.model
175
-
176
- if use_safetensors:
177
- save_file(model.state_dict(), path)
178
- else:
179
- torch.save({"state_dict": model.state_dict()}, path)
180
-
181
-
182
- class AudioLanguageModelDemoCallback(pl.Callback):
183
- def __init__(self,
184
- demo_every=2000,
185
- num_demos=8,
186
- sample_size=65536,
187
- sample_rate=48000,
188
- demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
189
- demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
190
- **kwargs
191
- ):
192
- super().__init__()
193
-
194
- self.demo_every = demo_every
195
- self.num_demos = num_demos
196
- self.demo_samples = sample_size
197
- self.sample_rate = sample_rate
198
- self.last_demo_step = -1
199
- self.demo_conditioning = demo_conditioning
200
- self.demo_cfg_scales = demo_cfg_scales
201
-
202
- @rank_zero_only
203
- @torch.no_grad()
204
- def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx):
205
-
206
- if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
207
- return
208
-
209
- module.eval()
210
-
211
- print(f"Generating demo")
212
- self.last_demo_step = trainer.global_step
213
-
214
- demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio
215
-
216
- #demo_reals = batch[0][:self.num_demos]
217
-
218
- # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
219
- # demo_reals = demo_reals[0]
220
-
221
- #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals)
222
-
223
- ##Limit to first 50 tokens
224
- #demo_reals_tokens = demo_reals_tokens[:, :, :50]
225
-
226
- try:
227
- print("Getting conditioning")
228
-
229
- for cfg_scale in self.demo_cfg_scales:
230
-
231
- model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model
232
-
233
- print(f"Generating demo for cfg scale {cfg_scale}")
234
- fakes = model.generate_audio(
235
- batch_size=self.num_demos,
236
- max_gen_len=demo_length_tokens,
237
- conditioning=self.demo_conditioning,
238
- #init_data = demo_reals_tokens,
239
- cfg_scale=cfg_scale,
240
- temp=1.0,
241
- top_p=0.95
242
- )
243
-
244
- # Put the demos together
245
- fakes = rearrange(fakes, 'b d n -> d (b n)')
246
-
247
- log_dict = {}
248
-
249
- filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
250
- fakes = fakes / fakes.abs().max()
251
- fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
252
- torchaudio.save(filename, fakes, self.sample_rate)
253
-
254
- log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
255
- sample_rate=self.sample_rate,
256
- caption='Generated')
257
-
258
- log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
259
-
260
- trainer.logger.experiment.log(log_dict)
261
-
262
- except Exception as e:
263
- raise e
264
- finally:
265
- gc.collect()
266
- torch.cuda.empty_cache()
267
- module.train()
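The per-codebook cross entropy in `_compute_cross_entropy` above is straightforward to sanity-check on dummy tensors; the snippet below reproduces the same masking and averaging on random data (shapes and values are made up for illustration).

```python
import torch
import torch.nn.functional as F

# Toy shapes matching the docstring: logits [B, K, T, card], targets/mask [B, K, T]
B, K, T, card = 2, 4, 10, 1024
logits = torch.randn(B, K, T, card)
targets = torch.randint(0, card, (B, K, T))
mask = torch.ones(B, K, T, dtype=torch.bool)

ce = torch.zeros([])
for k in range(K):
    logits_k = logits[:, k].reshape(-1, card)   # [B*T, card]
    targets_k = targets[:, k].reshape(-1)       # [B*T]
    mask_k = mask[:, k].reshape(-1)             # [B*T]
    ce = ce + F.cross_entropy(logits_k[mask_k], targets_k[mask_k])
ce = ce / K  # average over codebooks, as in the training wrapper

print(float(ce), float(torch.exp(ce)))  # cross entropy and its perplexity on random data
```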
 
stable/build/lib/stable_audio_tools/training/losses/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .losses import *
 
 
stable/build/lib/stable_audio_tools/training/losses/auraloss.py DELETED
@@ -1,607 +0,0 @@
1
- # Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0
2
- # You can find the license at LICENSES/LICENSE_AURALOSS.txt
3
-
4
- import torch
5
- import numpy as np
6
- from typing import List, Any
7
- import scipy.signal
8
-
9
- def apply_reduction(losses, reduction="none"):
10
- """Apply reduction to collection of losses."""
11
- if reduction == "mean":
12
- losses = losses.mean()
13
- elif reduction == "sum":
14
- losses = losses.sum()
15
- return losses
16
-
17
- def get_window(win_type: str, win_length: int):
18
- """Return a window function.
19
-
20
- Args:
21
- win_type (str): Window type. Can be either one of the window functions provided in PyTorch
22
- ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
23
- or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
24
- win_length (int): Window length
25
-
26
- Returns:
27
- win: The window as a 1D torch tensor
28
- """
29
-
30
- try:
31
- win = getattr(torch, win_type)(win_length)
32
- except:
33
- win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
34
-
35
- return win
36
-
37
- class SumAndDifference(torch.nn.Module):
38
- """Sum and difference signal extraction module."""
39
-
40
- def __init__(self):
41
- """Initialize sum and difference extraction module."""
42
- super(SumAndDifference, self).__init__()
43
-
44
- def forward(self, x):
45
- """Calculate forward propagation.
46
-
47
- Args:
48
- x (Tensor): Predicted signal (B, #channels, #samples).
49
- Returns:
50
- Tensor: Sum signal.
51
- Tensor: Difference signal.
52
- """
53
- if not (x.size(1) == 2): # inputs must be stereo
54
- raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")
55
-
56
- sum_sig = self.sum(x).unsqueeze(1)
57
- diff_sig = self.diff(x).unsqueeze(1)
58
-
59
- return sum_sig, diff_sig
60
-
61
- @staticmethod
62
- def sum(x):
63
- return x[:, 0, :] + x[:, 1, :]
64
-
65
- @staticmethod
66
- def diff(x):
67
- return x[:, 0, :] - x[:, 1, :]
68
-
69
-
70
- class FIRFilter(torch.nn.Module):
71
- """FIR pre-emphasis filtering module.
72
-
73
- Args:
74
- filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
75
- coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
76
- ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
77
- plot (bool): Plot the magnitude response of the filter. Default: False
78
-
79
- Based upon the perceptual loss pre-emphasis filters proposed by
80
- [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
81
-
82
- A-weighting filter - "aw"
83
- First-order highpass - "hp"
84
- Folded differentiator - "fd"
85
-
86
- Note that the default coefficient value of 0.85 is optimized for
88
- a sampling rate of 44.1 kHz; consider adjusting this value at different sampling rates.
88
- """
89
-
90
- def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
91
- """Initialize FIR pre-emphasis filtering module."""
92
- super(FIRFilter, self).__init__()
93
- self.filter_type = filter_type
94
- self.coef = coef
95
- self.fs = fs
96
- self.ntaps = ntaps
97
- self.plot = plot
98
-
99
- import scipy.signal
100
-
101
- if ntaps % 2 == 0:
102
- raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
103
-
104
- if filter_type == "hp":
105
- self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
106
- self.fir.weight.requires_grad = False
107
- self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
108
- elif filter_type == "fd":
109
- self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
110
- self.fir.weight.requires_grad = False
111
- self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
112
- elif filter_type == "aw":
113
- # Definition of analog A-weighting filter according to IEC/CD 1672.
114
- f1 = 20.598997
115
- f2 = 107.65265
116
- f3 = 737.86223
117
- f4 = 12194.217
118
- A1000 = 1.9997
119
-
120
- NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
121
- DENs = np.polymul(
122
- [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
123
- [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
124
- )
125
- DENs = np.polymul(
126
- np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
127
- )
128
-
129
- # convert analog filter to digital filter
130
- b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
131
-
132
- # compute the digital filter frequency response
133
- w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
134
-
135
- # then we fit to 101 tap FIR filter with least squares
136
- taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
137
-
138
- # now implement this digital FIR filter as a Conv1d layer
139
- self.fir = torch.nn.Conv1d(
140
- 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
141
- )
142
- self.fir.weight.requires_grad = False
143
- self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
144
-
145
- if plot:
146
- from .plotting import compare_filters
147
- compare_filters(b, a, taps, fs=fs)
148
-
149
- def forward(self, input, target):
150
- """Calculate forward propagation.
151
- Args:
152
- input (Tensor): Predicted signal (B, #channels, #samples).
153
- target (Tensor): Groundtruth signal (B, #channels, #samples).
154
- Returns:
155
- Tensor: Filtered signal.
156
- """
157
- input = torch.nn.functional.conv1d(
158
- input, self.fir.weight.data, padding=self.ntaps // 2
159
- )
160
- target = torch.nn.functional.conv1d(
161
- target, self.fir.weight.data, padding=self.ntaps // 2
162
- )
163
- return input, target
164
-
165
- class SpectralConvergenceLoss(torch.nn.Module):
166
- """Spectral convergence loss module.
167
-
168
- See [Arik et al., 2018](https://arxiv.org/abs/1808.06719).
169
- """
170
-
171
- def __init__(self):
172
- super(SpectralConvergenceLoss, self).__init__()
173
-
174
- def forward(self, x_mag, y_mag):
175
- return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean()
176
-
177
- class STFTMagnitudeLoss(torch.nn.Module):
178
- """STFT magnitude loss module.
179
-
180
- See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
181
- and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
182
-
183
- Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
184
- compression strength (larger value results in more compression), and `log_eps` can be used
185
- to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
186
- output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.
187
-
188
- Args:
189
- log (bool, optional): Log-scale the STFT magnitudes,
190
- or use linear scale. Default: True
191
- log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
192
- Default: 0.0
193
- log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
194
- Default: 1.0
195
- distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
196
- reduction (str, optional): Reduction of the loss elements. Default: "mean"
197
- """
198
-
199
- def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
200
- super(STFTMagnitudeLoss, self).__init__()
201
-
202
- self.log = log
203
- self.log_eps = log_eps
204
- self.log_fac = log_fac
205
-
206
- if distance == "L1":
207
- self.distance = torch.nn.L1Loss(reduction=reduction)
208
- elif distance == "L2":
209
- self.distance = torch.nn.MSELoss(reduction=reduction)
210
- else:
211
- raise ValueError(f"Invalid distance: '{distance}'.")
212
-
213
- def forward(self, x_mag, y_mag):
214
- if self.log:
215
- x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
216
- y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
217
- return self.distance(x_mag, y_mag)
218
-
219
-
220
- class STFTLoss(torch.nn.Module):
221
- """STFT loss module.
222
-
223
- See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472).
224
-
225
- Args:
226
- fft_size (int, optional): FFT size in samples. Default: 1024
227
- hop_size (int, optional): Hop size of the FFT in samples. Default: 256
228
- win_length (int, optional): Length of the FFT analysis window. Default: 1024
229
- window (str, optional): Window to apply before FFT, can be either one of the window functions provided in PyTorch
230
- ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
231
- or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
232
- Default: 'hann_window'
233
- w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
234
- w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
235
- w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
236
- w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
237
- sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
238
- scale (str, optional): Optional frequency scaling method, options include:
239
- ['mel', 'chroma']
240
- Default: None
241
- n_bins (int, optional): Number of scaling frequency bins. Default: None.
242
- perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
243
- scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
244
- eps (float, optional): Small epsilon value for stability. Default: 1e-8
245
- output (str, optional): Format of the loss returned.
246
- 'loss' : Return only the raw, aggregate loss term.
247
- 'full' : Return the raw loss, plus intermediate loss terms.
248
- Default: 'loss'
249
- reduction (str, optional): Specifies the reduction to apply to the output:
250
- 'none': no reduction will be applied,
251
- 'mean': the sum of the output will be divided by the number of elements in the output,
252
- 'sum': the output will be summed.
253
- Default: 'mean'
254
- mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms.
255
- device (str, optional): Place the filterbanks on specified device. Default: None
256
-
257
- Returns:
258
- loss:
259
- Aggregate loss term. Only returned if output='loss' (the default).
260
- loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss:
261
- Aggregate and intermediate loss terms. Only returned if output='full'.
262
- """
263
-
264
- def __init__(
265
- self,
266
- fft_size: int = 1024,
267
- hop_size: int = 256,
268
- win_length: int = 1024,
269
- window: str = "hann_window",
270
- w_sc: float = 1.0,
271
- w_log_mag: float = 1.0,
272
- w_lin_mag: float = 0.0,
273
- w_phs: float = 0.0,
274
- sample_rate: float = None,
275
- scale: str = None,
276
- n_bins: int = None,
277
- perceptual_weighting: bool = False,
278
- scale_invariance: bool = False,
279
- eps: float = 1e-8,
280
- output: str = "loss",
281
- reduction: str = "mean",
282
- mag_distance: str = "L1",
283
- device: Any = None,
284
- **kwargs
285
- ):
286
- super().__init__()
287
- self.fft_size = fft_size
288
- self.hop_size = hop_size
289
- self.win_length = win_length
290
- self.window = get_window(window, win_length)
291
- self.w_sc = w_sc
292
- self.w_log_mag = w_log_mag
293
- self.w_lin_mag = w_lin_mag
294
- self.w_phs = w_phs
295
- self.sample_rate = sample_rate
296
- self.scale = scale
297
- self.n_bins = n_bins
298
- self.perceptual_weighting = perceptual_weighting
299
- self.scale_invariance = scale_invariance
300
- self.eps = eps
301
- self.output = output
302
- self.reduction = reduction
303
- self.mag_distance = mag_distance
304
- self.device = device
305
-
306
- self.phs_used = bool(self.w_phs)
307
-
308
- self.spectralconv = SpectralConvergenceLoss()
309
- self.logstft = STFTMagnitudeLoss(
310
- log=True,
311
- reduction=reduction,
312
- distance=mag_distance,
313
- **kwargs
314
- )
315
- self.linstft = STFTMagnitudeLoss(
316
- log=False,
317
- reduction=reduction,
318
- distance=mag_distance,
319
- **kwargs
320
- )
321
-
322
- # setup mel filterbank
323
- if scale is not None:
324
- try:
325
- import librosa.filters
326
- except Exception as e:
327
- print(e)
328
- print("Try `pip install auraloss[all]`.")
329
-
330
- if self.scale == "mel":
331
- assert sample_rate != None # Must set sample rate to use mel scale
332
- assert n_bins <= fft_size # Must be more FFT bins than Mel bins
333
- fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
334
- fb = torch.tensor(fb).unsqueeze(0)
335
-
336
- elif self.scale == "chroma":
337
- assert sample_rate != None # Must set sample rate to use chroma scale
338
- assert n_bins <= fft_size # Must be more FFT bins than chroma bins
339
- fb = librosa.filters.chroma(
340
- sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
341
- )
342
-
343
- else:
344
- raise ValueError(
345
- f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
346
- )
347
-
348
- self.register_buffer("fb", fb)
349
-
350
- if scale is not None and device is not None:
351
- self.fb = self.fb.to(self.device) # move filterbank to device
352
-
353
- if self.perceptual_weighting:
354
- if sample_rate is None:
355
- raise ValueError(
356
- f"`sample_rate` must be supplied when `perceptual_weighting = True`."
357
- )
358
- self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
359
-
360
- def stft(self, x):
361
- """Perform STFT.
362
- Args:
363
- x (Tensor): Input signal tensor (B, T).
364
-
365
- Returns:
366
- Tensor: x_mag, x_phs
367
- Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
368
- """
369
- x_stft = torch.stft(
370
- x,
371
- self.fft_size,
372
- self.hop_size,
373
- self.win_length,
374
- self.window,
375
- return_complex=True,
376
- )
377
- x_mag = torch.sqrt(
378
- torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
379
- )
380
-
381
- # torch.angle is expensive, so it is only evaluated if the values are used in the loss
382
- if self.phs_used:
383
- x_phs = torch.angle(x_stft)
384
- else:
385
- x_phs = None
386
-
387
- return x_mag, x_phs
388
-
389
- def forward(self, input: torch.Tensor, target: torch.Tensor):
390
- bs, chs, seq_len = input.size()
391
-
392
- if self.perceptual_weighting: # apply optional A-weighting via FIR filter
393
- # since FIRFilter only support mono audio we will move channels to batch dim
394
- input = input.view(bs * chs, 1, -1)
395
- target = target.view(bs * chs, 1, -1)
396
-
397
- # now apply the filter to both
398
- self.prefilter.to(input.device)
399
- input, target = self.prefilter(input, target)
400
-
401
- # now move the channels back
402
- input = input.view(bs, chs, -1)
403
- target = target.view(bs, chs, -1)
404
-
405
- # compute the magnitude and phase spectra of input and target
406
- self.window = self.window.to(input.device)
407
-
408
- x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
409
- y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))
410
-
411
- # apply relevant transforms
412
- if self.scale is not None:
413
- self.fb = self.fb.to(input.device)
414
- x_mag = torch.matmul(self.fb, x_mag)
415
- y_mag = torch.matmul(self.fb, y_mag)
416
-
417
- # normalize scales
418
- if self.scale_invariance:
419
- alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
420
- y_mag = y_mag * alpha.unsqueeze(-1)
421
-
422
- # compute loss terms
423
- sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
424
- log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
425
- lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
426
- phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0
427
-
428
- # combine loss terms
429
- loss = (
430
- (self.w_sc * sc_mag_loss)
431
- + (self.w_log_mag * log_mag_loss)
432
- + (self.w_lin_mag * lin_mag_loss)
433
- + (self.w_phs * phs_loss)
434
- )
435
-
436
- loss = apply_reduction(loss, reduction=self.reduction)
437
-
438
- if self.output == "loss":
439
- return loss
440
- elif self.output == "full":
441
- return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
442
-
443
- class MultiResolutionSTFTLoss(torch.nn.Module):
444
- """Multi resolution STFT loss module.
445
-
446
- See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480)
447
-
448
- Args:
449
- fft_sizes (list): List of FFT sizes.
450
- hop_sizes (list): List of hop sizes.
451
- win_lengths (list): List of window lengths.
452
- window (str, optional): Window to apply before FFT, options include:
453
- 'hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
454
- Default: 'hann_window'
455
- w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
456
- w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
457
- w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
458
- w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
459
- sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
460
- scale (str, optional): Optional frequency scaling method, options include:
461
- ['mel', 'chroma']
462
- Default: None
463
- n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None.
464
- scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
465
- """
466
-
467
- def __init__(
468
- self,
469
- fft_sizes: List[int] = [1024, 2048, 512],
470
- hop_sizes: List[int] = [120, 240, 50],
471
- win_lengths: List[int] = [600, 1200, 240],
472
- window: str = "hann_window",
473
- w_sc: float = 1.0,
474
- w_log_mag: float = 1.0,
475
- w_lin_mag: float = 0.0,
476
- w_phs: float = 0.0,
477
- sample_rate: float = None,
478
- scale: str = None,
479
- n_bins: int = None,
480
- perceptual_weighting: bool = False,
481
- scale_invariance: bool = False,
482
- **kwargs,
483
- ):
484
- super().__init__()
485
- assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
486
- self.fft_sizes = fft_sizes
487
- self.hop_sizes = hop_sizes
488
- self.win_lengths = win_lengths
489
-
490
- self.stft_losses = torch.nn.ModuleList()
491
- for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
492
- self.stft_losses += [
493
- STFTLoss(
494
- fs,
495
- ss,
496
- wl,
497
- window,
498
- w_sc,
499
- w_log_mag,
500
- w_lin_mag,
501
- w_phs,
502
- sample_rate,
503
- scale,
504
- n_bins,
505
- perceptual_weighting,
506
- scale_invariance,
507
- **kwargs,
508
- )
509
- ]
510
-
511
- def forward(self, x, y):
512
- mrstft_loss = 0.0
513
- sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], []
514
-
515
- for f in self.stft_losses:
516
- if f.output == "full": # extract just first term
517
- tmp_loss = f(x, y)
518
- mrstft_loss += tmp_loss[0]
519
- sc_mag_loss.append(tmp_loss[1])
520
- log_mag_loss.append(tmp_loss[2])
521
- lin_mag_loss.append(tmp_loss[3])
522
- phs_loss.append(tmp_loss[4])
523
- else:
524
- mrstft_loss += f(x, y)
525
-
526
- mrstft_loss /= len(self.stft_losses)
527
-
528
- if f.output == "loss":
529
- return mrstft_loss
530
- else:
531
- return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
532
-
533
-
534
- class SumAndDifferenceSTFTLoss(torch.nn.Module):
535
- """Sum and difference stereo STFT loss module.
536
-
537
- See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
538
-
539
- Args:
540
- fft_sizes (List[int]): List of FFT sizes.
541
- hop_sizes (List[int]): List of hop sizes.
542
- win_lengths (List[int]): List of window lengths.
543
- window (str, optional): Window function type.
544
- w_sum (float, optional): Weight of the sum loss component. Default: 1.0
545
- w_diff (float, optional): Weight of the difference loss component. Default: 1.0
546
- perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
547
- mel_stft (bool, optional): Use multi-resolution mel spectrograms. Default: False
548
- n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
549
- sample_rate (float, optional): Audio sample rate. Default: None
550
- output (str, optional): Format of the loss returned.
551
- 'loss' : Return only the raw, aggregate loss term.
552
- 'full' : Return the raw loss, plus intermediate loss terms.
553
- Default: 'loss'
554
- """
555
-
556
- def __init__(
557
- self,
558
- fft_sizes: List[int],
559
- hop_sizes: List[int],
560
- win_lengths: List[int],
561
- window: str = "hann_window",
562
- w_sum: float = 1.0,
563
- w_diff: float = 1.0,
564
- output: str = "loss",
565
- **kwargs,
566
- ):
567
- super().__init__()
568
- self.sd = SumAndDifference()
569
- self.w_sum = w_sum
570
- self.w_diff = w_diff
571
- self.output = output
572
- self.mrstft = MultiResolutionSTFTLoss(
573
- fft_sizes,
574
- hop_sizes,
575
- win_lengths,
576
- window,
577
- **kwargs,
578
- )
579
-
580
- def forward(self, input: torch.Tensor, target: torch.Tensor):
581
- """This loss function assumes batched input of stereo audio in the time domain.
582
-
583
- Args:
584
- input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
585
- target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
586
-
587
- Returns:
588
- loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
589
- loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
590
- Aggregate and intermediate loss terms. Only returned if output='full'.
591
- """
592
- assert input.shape == target.shape # must have same shape
593
- bs, chs, seq_len = input.size()
594
-
595
- # compute sum and difference signals for both
596
- input_sum, input_diff = self.sd(input)
597
- target_sum, target_diff = self.sd(target)
598
-
599
- # compute error in STFT domain
600
- sum_loss = self.mrstft(input_sum, target_sum)
601
- diff_loss = self.mrstft(input_diff, target_diff)
602
- loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2
603
-
604
- if self.output == "loss":
605
- return loss
606
- elif self.output == "full":
607
- return loss, sum_loss, diff_loss
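As a usage sketch, the multi-resolution loss above can be exercised on random stereo audio as shown below. The import path assumes this build-tree module; the same class is also available from the upstream auraloss package, from which this file was adapted.

```python
import torch

# Assumed local import path; auraloss.freq.MultiResolutionSTFTLoss is the upstream equivalent.
from stable_audio_tools.training.losses.auraloss import MultiResolutionSTFTLoss

loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[2048, 1024, 512],
    hop_sizes=[512, 256, 128],
    win_lengths=[2048, 1024, 512],
    sample_rate=44100,
    perceptual_weighting=True,  # applies the A-weighting FIR prefilter defined above
)

pred = torch.randn(4, 2, 44100)    # (batch, channels, samples)
target = torch.randn(4, 2, 44100)
print(loss_fn(pred, target))       # scalar loss averaged over the three resolutions
```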
 
stable/build/lib/stable_audio_tools/training/losses/losses.py DELETED
@@ -1,101 +0,0 @@
1
- import typing as tp
2
-
3
- from torch.nn import functional as F
4
- from torch import nn
5
-
6
- class LossModule(nn.Module):
7
- def __init__(self, name: str, weight: float = 1.0):
8
- super().__init__()
9
-
10
- self.name = name
11
- self.weight = weight
12
-
13
- def forward(self, info, *args, **kwargs):
14
- raise NotImplementedError
15
-
16
- class ValueLoss(LossModule):
17
- def __init__(self, key: str, name, weight: float = 1.0):
18
- super().__init__(name=name, weight=weight)
19
-
20
- self.key = key
21
-
22
- def forward(self, info):
23
- return self.weight * info[self.key]
24
-
25
- class L1Loss(LossModule):
26
- def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'):
27
- super().__init__(name=name, weight=weight)
28
-
29
- self.key_a = key_a
30
- self.key_b = key_b
31
-
32
- self.mask_key = mask_key
33
-
34
- def forward(self, info):
35
- mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none')
36
-
37
- if self.mask_key is not None and self.mask_key in info:
38
- mse_loss = mse_loss[info[self.mask_key]]
39
-
40
- mse_loss = mse_loss.mean()
41
-
42
- return self.weight * mse_loss
43
-
44
- class MSELoss(LossModule):
45
- def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'):
46
- super().__init__(name=name, weight=weight)
47
-
48
- self.key_a = key_a
49
- self.key_b = key_b
50
-
51
- self.mask_key = mask_key
52
-
53
- def forward(self, info):
54
- mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none')
55
-
56
- if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None:
57
- mask = info[self.mask_key]
58
-
59
- if mask.ndim == 2 and mse_loss.ndim == 3:
60
- mask = mask.unsqueeze(1)
61
-
62
- if mask.shape[1] != mse_loss.shape[1]:
63
- mask = mask.repeat(1, mse_loss.shape[1], 1)
64
-
65
- mse_loss = mse_loss[mask]
66
-
67
- mse_loss = mse_loss.mean()
68
-
69
- return self.weight * mse_loss
70
-
71
- class AuralossLoss(LossModule):
72
- def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1):
73
- super().__init__(name, weight)
74
-
75
- self.auraloss_module = auraloss_module
76
-
77
- self.input_key = input_key
78
- self.target_key = target_key
79
-
80
- def forward(self, info):
81
- loss = self.auraloss_module(info[self.input_key], info[self.target_key])
82
-
83
- return self.weight * loss
84
-
85
- class MultiLoss(nn.Module):
86
- def __init__(self, losses: tp.List[LossModule]):
87
- super().__init__()
88
-
89
- self.losses = nn.ModuleList(losses)
90
-
91
- def forward(self, info):
92
- total_loss = 0
93
-
94
- losses = {}
95
-
96
- for loss_module in self.losses:
97
- module_loss = loss_module(info)
98
- total_loss += module_loss
99
- losses[loss_module.name] = module_loss
100
-
101
- return total_loss, losses
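These loss modules all read from a shared `info` dictionary, so composing them is just a matter of agreeing on keys. A small illustration with made-up keys and random tensors:

```python
import torch

# Assumed import path mirroring this build tree.
from stable_audio_tools.training.losses.losses import MSELoss, ValueLoss, MultiLoss

multi = MultiLoss([
    MSELoss("output", "target", weight=1.0, name="mse"),
    ValueLoss("kl", name="kl", weight=1e-4),   # passes through a precomputed scalar
])

info = {
    "output": torch.randn(2, 64, 100),
    "target": torch.randn(2, 64, 100),
    "kl": torch.tensor(0.5),
}

total, per_loss = multi(info)
print(total, per_loss)  # weighted sum and a dict of the individual terms
```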
 
stable/build/lib/stable_audio_tools/training/utils.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
- import os
3
-
4
- def get_rank():
5
- """Get rank of current process."""
6
-
7
- print(os.environ.keys())
8
-
9
- if "SLURM_PROCID" in os.environ:
10
- return int(os.environ["SLURM_PROCID"])
11
-
12
- if not torch.distributed.is_available() or not torch.distributed.is_initialized():
13
- return 0
14
-
15
- return torch.distributed.get_rank()
16
-
17
- class InverseLR(torch.optim.lr_scheduler._LRScheduler):
18
- """Implements an inverse decay learning rate schedule with an optional exponential
19
- warmup. When last_epoch=-1, sets initial lr as lr.
20
- inv_gamma is the number of steps/epochs required for the learning rate to decay to
21
- (1 / 2)**power of its original value.
22
- Args:
23
- optimizer (Optimizer): Wrapped optimizer.
24
- inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
25
- power (float): Exponential factor of learning rate decay. Default: 1.
26
- warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
27
- Default: 0.
28
- final_lr (float): The final learning rate. Default: 0.
29
- last_epoch (int): The index of last epoch. Default: -1.
30
- verbose (bool): If ``True``, prints a message to stdout for
31
- each update. Default: ``False``.
32
- """
33
-
34
- def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
35
- last_epoch=-1, verbose=False):
36
- self.inv_gamma = inv_gamma
37
- self.power = power
38
- if not 0. <= warmup < 1:
39
- raise ValueError('Invalid value for warmup')
40
- self.warmup = warmup
41
- self.final_lr = final_lr
42
- super().__init__(optimizer, last_epoch, verbose)
43
-
44
- def get_lr(self):
45
- if not self._get_lr_called_within_step:
46
- import warnings
47
- warnings.warn("To get the last learning rate computed by the scheduler, "
48
- "please use `get_last_lr()`.")
49
-
50
- return self._get_closed_form_lr()
51
-
52
- def _get_closed_form_lr(self):
53
- warmup = 1 - self.warmup ** (self.last_epoch + 1)
54
- lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
55
- return [warmup * max(self.final_lr, base_lr * lr_mult)
56
- for base_lr in self.base_lrs]
57
-
58
- def copy_state_dict(model, state_dict):
59
- """Load state_dict to model, but only for keys that match exactly.
60
-
61
- Args:
62
- model (nn.Module): model to load state_dict.
63
- state_dict (OrderedDict): state_dict to load.
64
- """
65
- model_state_dict = model.state_dict()
66
- for key in state_dict:
67
- if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
68
- if isinstance(state_dict[key], torch.nn.Parameter):
69
- # backwards compatibility for serialized parameters
70
- state_dict[key] = state_dict[key].data
71
- model_state_dict[key] = state_dict[key]
72
-
73
- model.load_state_dict(model_state_dict, strict=False)
74
-
75
- def create_optimizer_from_config(optimizer_config, parameters):
76
- """Create optimizer from config.
77
-
78
- Args:
79
- parameters (iterable): parameters to optimize.
80
- optimizer_config (dict): optimizer config.
81
-
82
- Returns:
83
- torch.optim.Optimizer: optimizer.
84
- """
85
-
86
- optimizer_type = optimizer_config["type"]
87
-
88
- if optimizer_type == "FusedAdam":
89
- from deepspeed.ops.adam import FusedAdam
90
- optimizer = FusedAdam(parameters, **optimizer_config["config"])
91
- else:
92
- optimizer_fn = getattr(torch.optim, optimizer_type)
93
- optimizer = optimizer_fn(parameters, **optimizer_config["config"])
94
- return optimizer
95
-
96
- def create_scheduler_from_config(scheduler_config, optimizer):
97
- """Create scheduler from config.
98
-
99
- Args:
100
- scheduler_config (dict): scheduler config.
101
- optimizer (torch.optim.Optimizer): optimizer.
102
-
103
- Returns:
104
- torch.optim.lr_scheduler._LRScheduler: scheduler.
105
- """
106
- if scheduler_config["type"] == "InverseLR":
107
- scheduler_fn = InverseLR
108
- else:
109
- scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
110
- scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
111
- return scheduler
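The optimizer and scheduler factories above map directly onto the `optimizer_configs` block of the adapter config that follows. A minimal sketch, using a throwaway parameter so it runs standalone:

```python
import torch

# Assumed import path mirroring this build tree.
from stable_audio_tools.training.utils import (
    create_optimizer_from_config,
    create_scheduler_from_config,
)

params = [torch.nn.Parameter(torch.zeros(10))]

opt = create_optimizer_from_config(
    {"type": "AdamW", "config": {"lr": 3e-3, "betas": [0.9, 0.999], "weight_decay": 1e-3}},
    params,
)
# InverseLR: lr_mult = (1 + step / inv_gamma) ** -power, with exponential warmup.
sched = create_scheduler_from_config(
    {"type": "InverseLR", "config": {"inv_gamma": 1_000_000, "power": 0.5, "warmup": 0.99}},
    opt,
)

for _ in range(3):
    opt.step()
    sched.step()
print(sched.get_last_lr())
```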
 
stable/config_adapter.json DELETED
@@ -1,124 +0,0 @@
1
- {
2
- "model_type": "diffusion_cond",
3
- "sample_size": 2097152,
4
- "sample_rate": 44100,
5
- "audio_channels": 2,
6
- "model": {
7
- "pretransform": {
8
- "type": "autoencoder",
9
- "iterate_batch": true,
10
- "config": {
11
- "encoder": {
12
- "type": "oobleck",
13
- "requires_grad": false,
14
- "config": {
15
- "in_channels": 2,
16
- "channels": 128,
17
- "c_mults": [1, 2, 4, 8, 16],
18
- "strides": [2, 4, 4, 8, 8],
19
- "latent_dim": 128,
20
- "use_snake": true
21
- }
22
- },
23
- "decoder": {
24
- "type": "oobleck",
25
- "config": {
26
- "out_channels": 2,
27
- "channels": 128,
28
- "c_mults": [1, 2, 4, 8, 16],
29
- "strides": [2, 4, 4, 8, 8],
30
- "latent_dim": 64,
31
- "use_snake": true,
32
- "final_tanh": false
33
- }
34
- },
35
- "bottleneck": {
36
- "type": "vae"
37
- },
38
- "latent_dim": 64,
39
- "downsampling_ratio": 2048,
40
- "io_channels": 2
41
- }
42
- },
43
- "conditioning": {
44
- "configs": [
45
- {
46
- "id": "prompt",
47
- "type": "t5",
48
- "config": {
49
- "t5_model_name": "t5-base",
50
- "max_length": 128
51
- }
52
- },
53
- {
54
- "id": "seconds_start",
55
- "type": "number",
56
- "config": {
57
- "min_val": 0,
58
- "max_val": 512
59
- }
60
- },
61
- {
62
- "id": "seconds_total",
63
- "type": "number",
64
- "config": {
65
- "min_val": 0,
66
- "max_val": 512
67
- }
68
- }
69
- ],
70
- "cond_dim": 768
71
- },
72
- "diffusion": {
73
- "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"],
74
- "global_cond_ids": ["seconds_start", "seconds_total"],
75
- "type": "dit",
76
- "config": {
77
- "io_channels": 64,
78
- "embed_dim": 1536,
79
- "depth": 24,
80
- "num_heads": 24,
81
- "cond_token_dim": 768,
82
- "global_cond_dim": 1536,
83
- "project_cond_tokens": false,
84
- "transformer_type": "continuous_transformer",
85
- "adapter_present": true
86
- }
87
- },
88
- "io_channels": 64
89
- },
90
- "training": {
91
- "use_ema": true,
92
- "log_loss_info": false,
93
- "optimizer_configs": {
94
- "diffusion": {
95
- "adapter_present": true,
96
- "optimizer": {
97
- "type": "AdamW",
98
- "config": {
99
- "lr": 3e-3,
100
- "betas": [0.9, 0.999],
101
- "weight_decay": 1e-3
102
- }
103
- },
104
- "scheduler": {
105
- "type": "InverseLR",
106
- "config": {
107
- "inv_gamma": 1000000,
108
- "power": 0.5,
109
- "warmup": 0.99
110
- }
111
- }
112
- }
113
- },
114
- "demo": {
115
- "demo_every": 15,
116
- "demo_steps": 250,
117
- "num_demos": 1,
118
- "demo_cond": [
119
- {"prompt": "Amen break 174 BPM", "seconds_start": 0, "seconds_total": 12}
120
- ],
121
- "demo_cfg_scales": [7]
122
- }
123
- }
124
- }
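One detail worth calling out in this config: the DiT never sees raw audio. The pretransform autoencoder downsamples by 2048x, so the 2,097,152-sample training window (roughly 47.5 s at 44.1 kHz) becomes a 1024-step latent sequence with 64 channels. A quick check:

```python
import json

with open("stable/config_adapter.json") as f:
    cfg = json.load(f)

sample_size = cfg["sample_size"]                                       # 2097152
ratio = cfg["model"]["pretransform"]["config"]["downsampling_ratio"]   # 2048
print(sample_size // ratio)                                            # 1024 latent steps
print(sample_size / cfg["sample_rate"])                                # ~47.55 seconds
```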
 
stable/convert_json.py DELETED
@@ -1,44 +0,0 @@
1
- import json
2
- import sys
3
-
4
- def update_path_in_json(input_file, output_file, new_path):
5
- # Read the input JSON file
6
- try:
7
- with open(input_file, 'r') as infile:
8
- data = json.load(infile)
9
- except FileNotFoundError:
10
- print(f"Input file {input_file} not found.")
11
- sys.exit(1)
12
- except json.JSONDecodeError:
13
- print(f"Error decoding JSON from the input file {input_file}.")
14
- sys.exit(1)
15
-
16
- # Update the path
17
- try:
18
- data['datasets'][0]['path'] = new_path
19
- except KeyError as e:
20
- print(f"Key error: {e}")
21
- sys.exit(1)
22
- except IndexError as e:
23
- print(f"Index error: {e}")
24
- sys.exit(1)
25
-
26
- # Write the updated JSON to the output file
27
- try:
28
- with open(output_file, 'w') as outfile:
29
- json.dump(data, outfile, indent=4)
30
- except IOError as e:
31
- print(f"Error writing to the output file {output_file}: {e}")
32
- sys.exit(1)
33
-
34
- print(f"Path updated successfully in {output_file}")
35
-
36
-
37
- if __name__ == "__main__":
38
- import argparse
39
- parser = argparse.ArgumentParser(description='Convert JSON for fine-tuning.')
40
- parser.add_argument('--input_json', type=str, help='Name of the dataset', required=True)
41
- parser.add_argument('--output_json', type=str, help='Path to the input CSV', required=True)
42
- parser.add_argument('--new_path', type=str, help='Path to output JSON', required=True)
43
- args = parser.parse_args()
44
- update_path_in_json(args.input_json, args.output_json, args.new_path)
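For completeness, this helper is normally invoked from the command line; the equivalent programmatic call is sketched below. The file names are placeholders, and the input JSON is expected to contain a `datasets[0].path` entry, as the function assumes.

```python
# Equivalent to:
#   python stable/convert_json.py --input_json dataset.json \
#          --output_json dataset_local.json --new_path /data/my_audio
# (run from the stable/ directory so the module is importable; paths are placeholders)
from convert_json import update_path_in_json

update_path_in_json("dataset.json", "dataset_local.json", "/data/my_audio")
```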