diff --git a/.github/actions/audiocraft_build/action.yml b/.github/actions/audiocraft_build/action.yml index be5dae26afef4c5e756135cbddfab034db6a016e..b412cd02c6fad53a23068fe89a961f4da2db9b9f 100644 --- a/.github/actions/audiocraft_build/action.yml +++ b/.github/actions/audiocraft_build/action.yml @@ -21,6 +21,8 @@ runs: python3 -m venv env . env/bin/activate python -m pip install --upgrade pip + pip install torch torchvision torchaudio + pip install xformers pip install -e '.[dev]' - name: System Dependencies shell: bash diff --git a/.github/workflows/audiocraft_docs.yml b/.github/workflows/audiocraft_docs.yml index 668498cd718bdceaf5355e96ed51d6d25d0f61ab..96340dffe809e06f13c896b7d823f89b168221c2 100644 --- a/.github/workflows/audiocraft_docs.yml +++ b/.github/workflows/audiocraft_docs.yml @@ -23,9 +23,9 @@ jobs: - name: Make docs run: | . env/bin/activate - make docs - git add -f docs - git commit -m docs + make api_docs + git add -f api_docs + git commit -m api_docs - name: Push branch run: | diff --git a/.github/workflows/audiocraft_tests.yml b/.github/workflows/audiocraft_tests.yml index d14be656d51b49e58829d8d423257ae65cb02fd6..829b37aa777c52c2bd35ee2b0baf2af6c02a5962 100644 --- a/.github/workflows/audiocraft_tests.yml +++ b/.github/workflows/audiocraft_tests.yml @@ -12,6 +12,11 @@ jobs: steps: - uses: actions/checkout@v2 - uses: ./.github/actions/audiocraft_build - - run: | + - name: Run unit tests + run: | . env/bin/activate make tests + - name: Run integration tests + run: | + . env/bin/activate + make tests_integ diff --git a/.gitignore b/.gitignore index 85968eaac10faa4c0180acf8038c18bfd92bb369..40392acbd02a75a755ab1e038cd7e8457663b94d 100644 --- a/.gitignore +++ b/.gitignore @@ -35,7 +35,7 @@ wheels/ .coverage # docs -/docs +/api_docs # dotenv .env @@ -46,6 +46,13 @@ wheels/ venv/ ENV/ +# egs with manifest files +egs/* +!egs/example +# local datasets +dataset/* +!dataset/example + # personal notebooks & scripts */local_scripts */notes diff --git a/CHANGELOG.md b/CHANGELOG.md index 24fc214df236b40efead4b1585b01632d9658e9b..6036b72f02e56858adc3d564451aaf4d2d175da1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,37 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). -## [0.0.2a] - TBD +## [1.2.0a] - TBD + +Adding stereo models. + + +## [1.1.0] - 2023-11-06 + +Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons. + +Fixed DAC support with non default number of codebooks. + +Fixed bug when `two_step_cfg` was overriden when calling `generate()`. + +Fixed samples being always prompted with audio, rather than having both prompted and unprompted. + +**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release. + The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners. + We removed it, so you might need to retrain models. + +**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before). + +**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one + retrained a model with this pattern, so hopefully this won't impact you! + + +## [1.0.0] - 2023-09-07 + +Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion. +Added pretrained model for AudioGen and MultiBandDiffusion. + +## [0.0.2] - 2023-08-01 Improved demo, fixed top p (thanks @jnordberg). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 55b99140204d785d572ada9761dd77f302ae31c6..a3e9507643d4439f509a8fc8b87dc73417ef9822 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,11 +1,11 @@ -# Contributing to Audiocraft +# Contributing to AudioCraft We want to make contributing to this project as easy and transparent as possible. ## Pull Requests -Audiocraft is the implementation of a research paper. +AudioCraft is the implementation of a research paper. Therefore, we do not plan on accepting many pull requests for new features. We certainly welcome them for bug fixes. diff --git a/LICENSE_weights b/LICENSE_weights index dc1adf98654156baeb94d2e055c224a847e5820d..108b5f002fc31efe11d881de2cd05329ebe8cc37 100644 --- a/LICENSE_weights +++ b/LICENSE_weights @@ -1,157 +1,399 @@ -# Attribution-NonCommercial-NoDerivatives 4.0 International - -> *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.* -> -> ### Using Creative Commons Public Licenses -> -> Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. -> -> * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). -> -> * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). - -## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License - -By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. - -### Section 1 – Definitions. - -a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. - -b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. - -e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. - -f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. - -h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. - -i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. - -h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. - -i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. - -j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. - -k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. - -l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. - -### Section 2 – Scope. - -a. ___License grant.___ - - 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: - - A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and - - B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only. - - 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. - - 3. __Term.__ The term of this Public License is specified in Section 6(a). - - 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. - - 5. __Downstream recipients.__ - - A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. - - B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. - - 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). - -b. ___Other rights.___ - - 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. - - 2. Patent and trademark rights are not licensed under this Public License. - - 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. - -### Section 3 – License Conditions. - -Your exercise of the Licensed Rights is expressly made subject to the following conditions. - -a. ___Attribution.___ - - 1. If You Share the Licensed Material, You must: - - A. retain the following if it is supplied by the Licensor with the Licensed Material: - - i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); - - ii. a copyright notice; - - iii. a notice that refers to this Public License; - - iv. a notice that refers to the disclaimer of warranties; - - v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; - - B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and - - C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. - - For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material. - - 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. - - 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. - -### Section 4 – Sui Generis Database Rights. - -Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: - -a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material; - -b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and - -c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. - -For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. - -### Section 5 – Disclaimer of Warranties and Limitation of Liability. - -a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ - -b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ - -c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. - -### Section 6 – Term and Termination. - -a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. - -b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: - - 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or - - 2. upon express reinstatement by the Licensor. - - For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. - -c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. - -d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. - -### Section 7 – Other Terms and Conditions. - -a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. - -b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. - -### Section 8 – Interpretation. - -a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. - -b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. - -c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. - -d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. - -> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. -> -> Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org). +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/MANIFEST.in b/MANIFEST.in index e0dbcccf0e6c75e59bd81f108062f0b312ee93b9..4bfcf45c4e63ce27640e58cd4cde337e6d299844 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,3 +6,10 @@ include *.ini include requirements.txt include audiocraft/py.typed include assets/*.mp3 +include datasets/*.mp3 +recursive-include config *.yaml +recursive-include demos *.py +recursive-include demos *.ipynb +recursive-include scripts *.py +recursive-include model_cards *.md +recursive-include docs *.md diff --git a/Makefile b/Makefile index 5bfd89dd833d7448b21073eb6ee7cfac1d5157dd..3a4910066583dc22f06f5ec2d5711367c941c86b 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,15 @@ +INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \ + dataset.train.num_samples=10 dataset.valid.num_samples=10 \ + dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \ + logging.level=DEBUG +INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e +INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ + transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e +INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ + transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e +INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \ + checkpoint.save_last=false # Using compression model from 616d7b3c + default: linter tests install: @@ -10,12 +22,19 @@ linter: tests: coverage run -m pytest tests - coverage report --include 'audiocraft/*' + coverage report + +tests_integ: + $(INTEG_COMPRESSION) + $(INTEG_MBD) + $(INTEG_MUSICGEN) + $(INTEG_AUDIOGEN) + -docs: - pdoc3 --html -o docs -f audiocraft +api_docs: + pdoc3 --html -o api_docs -f audiocraft dist: python setup.py sdist -.PHONY: linter tests docs dist +.PHONY: linter tests api_docs dist diff --git a/README.md b/README.md index f798eab150e85eb88334c171a0223cab612043be..6c445e7dc908b8edeef39f2a4f44658c58113115 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ tags: - "music generation" - "language models" - "LLMs" -app_file: "app.py" +app_file: "demos/musicgen_app.py" emoji: 🎵 colorFrom: gray colorTo: blue @@ -14,33 +14,17 @@ sdk_version: 3.34.0 pinned: true license: "cc-by-nc-4.0" --- -# Audiocraft +# AudioCraft ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg) ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg) ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg) -Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model. +AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code +for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen. -## MusicGen - -Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive -Transformer model trained over a 32kHz EnCodec tokenizer with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates -all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict -them in parallel, thus having only 50 auto-regressive steps per second of audio. -Check out our [sample page][musicgen_samples] or test the available demo! - - - Open In Colab - - - Open in HugginFace - -
- -We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. ## Installation -Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following: +AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run the following: ```shell # Best to make sure you have torch installed first, in particular before installing xformers. @@ -49,92 +33,68 @@ pip install 'torch>=2.0' # Then proceed to one of the following pip install -U audiocraft # stable release pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge -pip install -e . # or if you cloned the repo locally +pip install -e . # or if you cloned the repo locally (mandatory if you want to train). ``` -## Usage -We offer a number of way to interact with MusicGen: -1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co./spaces/facebook/MusicGen) (huge thanks to all the HF team for their support). -2. You can run the Gradio demo in Colab: [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing). -3. You can use the gradio demo locally by running `python app.py`. -4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU). -5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) which is regularly - updated with contributions from @camenduru and the community. - -## API - -We provide a simple API and 4 pre-trained models. The pre trained models are: -- `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co./facebook/musicgen-small) -- `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co./facebook/musicgen-medium) -- `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co./facebook/musicgen-melody) -- `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co./facebook/musicgen-large) - -We observe the best trade-off between quality and compute with the `medium` or `melody` model. -In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller -GPUs will be able to generate short sequences, or longer sequences with the `small` model. - -**Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`. -You can install it with: -``` -apt-get install ffmpeg +We also recommend having `ffmpeg` installed, either through your system or Anaconda: +```bash +sudo apt-get install ffmpeg +# Or if you are using Anaconda or Miniconda +conda install "ffmpeg<5" -c conda-forge ``` -See after a quick example for using the API. +## Models -```python -import torchaudio -from audiocraft.models import MusicGen -from audiocraft.data.audio import audio_write +At the moment, AudioCraft contains the training code and inference code for: +* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model. +* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model. +* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec. +* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion. -model = MusicGen.get_pretrained('melody') -model.set_generation_params(duration=8) # generate 8 seconds. -wav = model.generate_unconditional(4) # generates 4 unconditional audio samples -descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] -wav = model.generate(descriptions) # generates 3 samples. +## Training code -melody, sr = torchaudio.load('./assets/bach.mp3') -# generates using the melody from the given audio and the provided descriptions. -wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr) +AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models. +For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to +the [AudioCraft training documentation](./docs/TRAINING.md). -for idx, one_wav in enumerate(wav): - # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. - audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) -``` +For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model +that provides pointers to configuration, example grids and model/task-specific information and FAQ. -## Model Card +## API documentation -See [the model card page](./MODEL_CARD.md). +We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft. -## FAQ -#### Will the training code be released? +## FAQ -Yes. We will soon release the training code for MusicGen and EnCodec. +#### Is the training code available? +Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md). -#### I need help on Windows +#### Where are the models stored? -@FurkanGozukara made a complete tutorial for [Audiocraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4) +Hugging Face stored the model in a specific location, which can be overriden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable for the AudioCraft models. +In order to change the cache location of the other Hugging Face models, please check out the [Hugging Face Transformers documentation for the cache setup](https://huggingface.co./docs/transformers/installation#cache-setup). +Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and want to change the download location for Demucs, refer to the [Torch Hub documentation](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved). -#### I need help for running the demo on Colab -Check [@camenduru tutorial on Youtube](https://www.youtube.com/watch?v=EGfxuTy9Eeo). +## License +* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). +* The models weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). ## Citation + +For the general framework of AudioCraft, please cite the following. ``` @article{copet2023simple, - title={Simple and Controllable Music Generation}, - author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, - year={2023}, - journal={arXiv preprint arXiv:2306.05284}, + title={Simple and Controllable Music Generation}, + author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, + year={2023}, + journal={arXiv preprint arXiv:2306.05284}, } ``` -## License -* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). -* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). - -[arxiv]: https://arxiv.org/abs/2306.05284 -[musicgen_samples]: https://ai.honu.io/papers/musicgen/ +When referring to a specific model, please cite as mentioned in the model specific README, e.g +[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc. diff --git a/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 b/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..71be35a12d3e97993996806d6a94175568b2761f Binary files /dev/null and b/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 differ diff --git a/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 b/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..e74b5b61a624fbf69f5e70febc64c91658bb38ac Binary files /dev/null and b/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 differ diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 6b8594f470200ff5c000542ef115375ed69b749c..251afe7df493d05d205e526787479f6b9d1f7964 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -3,8 +3,24 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +""" +AudioCraft is a general framework for training audio generative models. +At the moment we provide the training code for: + +- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art + text-to-music and melody+text autoregressive generative model. + For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model, + `audiocraft.models.musicgen.MusicGen`. +- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art + text-to-general-audio generative model. +- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity + neural audio codec which provides an excellent tokenizer for autoregressive language models. + See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`. +- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that + improves the perceived quality and reduces the artifacts coming from adversarial decoders. +""" # flake8: noqa from . import data, modules, models -__version__ = '0.0.2a2' +__version__ = '1.1.0' diff --git a/audiocraft/adversarial/__init__.py b/audiocraft/adversarial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..864058706fbfae13d7f7dc850cc411a2f27d1510 --- /dev/null +++ b/audiocraft/adversarial/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Adversarial losses and discriminator architectures.""" + +# flake8: noqa +from .discriminators import ( + MultiPeriodDiscriminator, + MultiScaleDiscriminator, + MultiScaleSTFTDiscriminator +) +from .losses import ( + AdversarialLoss, + AdvLossType, + get_adv_criterion, + get_fake_criterion, + get_real_criterion, + FeatLossType, + FeatureMatchingLoss +) diff --git a/audiocraft/adversarial/discriminators/__init__.py b/audiocraft/adversarial/discriminators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e5ff59950ee0b1d1a67c9b3831d67d08048148 --- /dev/null +++ b/audiocraft/adversarial/discriminators/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .mpd import MultiPeriodDiscriminator +from .msd import MultiScaleDiscriminator +from .msstftd import MultiScaleSTFTDiscriminator diff --git a/audiocraft/adversarial/discriminators/base.py b/audiocraft/adversarial/discriminators/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d517e9f5bf0f4e18252c45c8db3a35a7255f69 --- /dev/null +++ b/audiocraft/adversarial/discriminators/base.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +import typing as tp + +import torch +import torch.nn as nn + + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + + +class MultiDiscriminator(ABC, nn.Module): + """Base implementation for discriminators composed of sub-discriminators acting at different scales. + """ + def __init__(self): + super().__init__() + + @abstractmethod + def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: + ... + + @property + @abstractmethod + def num_discriminators(self) -> int: + """Number of discriminators. + """ + ... diff --git a/audiocraft/adversarial/discriminators/mpd.py b/audiocraft/adversarial/discriminators/mpd.py new file mode 100644 index 0000000000000000000000000000000000000000..8debd1fa72d77ca03df680facb60bdf79638cade --- /dev/null +++ b/audiocraft/adversarial/discriminators/mpd.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...modules import NormConv2d +from .base import MultiDiscriminator, MultiDiscriminatorOutputType + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +class PeriodDiscriminator(nn.Module): + """Period sub-discriminator. + + Args: + period (int): Period between samples of audio. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + n_layers (int): Number of convolutional layers. + kernel_sizes (list of int): Kernel sizes for convolutions. + stride (int): Stride for convolutions. + filters (int): Initial number of filters in convolutions. + filters_scale (int): Multiplier of number of filters as we increase depth. + max_filters (int): Maximum number of filters. + norm (str): Normalization method. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + """ + def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1, + n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3, + filters: int = 8, filters_scale: int = 4, max_filters: int = 1024, + norm: str = 'weight_norm', activation: str = 'LeakyReLU', + activation_params: dict = {'negative_slope': 0.2}): + super().__init__() + self.period = period + self.n_layers = n_layers + self.activation = getattr(torch.nn, activation)(**activation_params) + self.convs = nn.ModuleList() + in_chs = in_channels + for i in range(self.n_layers): + out_chs = min(filters * (filters_scale ** (i + 1)), max_filters) + eff_stride = 1 if i == self.n_layers - 1 else stride + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm)) + in_chs = out_chs + self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1, + padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm) + + def forward(self, x: torch.Tensor): + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), 'reflect') + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for conv in self.convs: + x = conv(x) + x = self.activation(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + # x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(MultiDiscriminator): + """Multi-Period (MPD) Discriminator. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + periods (Sequence[int]): Periods between samples of audio for the sub-discriminators. + **kwargs: Additional args for `PeriodDiscriminator` + """ + def __init__(self, in_channels: int = 1, out_channels: int = 1, + periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs): + super().__init__() + self.discriminators = nn.ModuleList([ + PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods + ]) + + @property + def num_discriminators(self): + return len(self.discriminators) + + def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps diff --git a/audiocraft/adversarial/discriminators/msd.py b/audiocraft/adversarial/discriminators/msd.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e67e29b46ab22f6ffeec85ffc64d8b99800b1b --- /dev/null +++ b/audiocraft/adversarial/discriminators/msd.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import numpy as np +import torch +import torch.nn as nn + +from ...modules import NormConv1d +from .base import MultiDiscriminator, MultiDiscriminatorOutputType + + +class ScaleDiscriminator(nn.Module): + """Waveform sub-discriminator. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions. + filters (int): Number of initial filters for convolutions. + max_filters (int): Maximum number of filters. + downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions. + inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions. + groups (Sequence[int] or None): Groups for inner convolutions. + strides (Sequence[int] or None): Strides for inner convolutions. + paddings (Sequence[int] or None): Paddings for inner convolutions. + norm (str): Normalization method. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + pad (str): Padding for initial convolution. + pad_params (dict): Parameters to provide to the padding module. + """ + def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3], + filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4], + inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None, + strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None, + norm: str = 'weight_norm', activation: str = 'LeakyReLU', + activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d', + pad_params: dict = {}): + super().__init__() + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1 + assert kernel_sizes[1] % 2 == 1 + assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales)) + assert (groups is None or len(groups) == len(downsample_scales)) + assert (strides is None or len(strides) == len(downsample_scales)) + assert (paddings is None or len(paddings) == len(downsample_scales)) + self.activation = getattr(torch.nn, activation)(**activation_params) + self.convs = nn.ModuleList() + self.convs.append( + nn.Sequential( + getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), + NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm) + ) + ) + + in_chs = filters + for i, downsample_scale in enumerate(downsample_scales): + out_chs = min(in_chs * downsample_scale, max_filters) + default_kernel_size = downsample_scale * 10 + 1 + default_stride = downsample_scale + default_padding = (default_kernel_size - 1) // 2 + default_groups = in_chs // 4 + self.convs.append( + NormConv1d(in_chs, out_chs, + kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size, + stride=strides[i] if strides else default_stride, + groups=groups[i] if groups else default_groups, + padding=paddings[i] if paddings else default_padding, + norm=norm)) + in_chs = out_chs + + out_chs = min(in_chs * 2, max_filters) + self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1, + padding=(kernel_sizes[0] - 1) // 2, norm=norm)) + self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1, + padding=(kernel_sizes[1] - 1) // 2, norm=norm) + + def forward(self, x: torch.Tensor): + fmap = [] + for layer in self.convs: + x = layer(x) + x = self.activation(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + # x = torch.flatten(x, 1, -1) + return x, fmap + + +class MultiScaleDiscriminator(MultiDiscriminator): + """Multi-Scale (MSD) Discriminator, + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + downsample_factor (int): Downsampling factor between the different scales. + scale_norms (Sequence[str]): Normalization for each sub-discriminator. + **kwargs: Additional args for ScaleDiscriminator. + """ + def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2, + scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs): + super().__init__() + self.discriminators = nn.ModuleList([ + ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms + ]) + self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor) + + @property + def num_discriminators(self): + return len(self.discriminators) + + def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: + logits = [] + fmaps = [] + for i, disc in enumerate(self.discriminators): + if i != 0: + self.downsample(x) + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps diff --git a/audiocraft/adversarial/discriminators/msstftd.py b/audiocraft/adversarial/discriminators/msstftd.py new file mode 100644 index 0000000000000000000000000000000000000000..81a9100961c7a89a39df2643b24268fb90bfeaa4 --- /dev/null +++ b/audiocraft/adversarial/discriminators/msstftd.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import torchaudio +import torch +from torch import nn +from einops import rearrange + +from ...modules import NormConv2d +from .base import MultiDiscriminator, MultiDiscriminatorOutputType + + +def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): + return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + + Args: + filters (int): Number of filters in convolutions. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + n_fft (int): Size of FFT for each scale. + hop_length (int): Length of hop between STFT windows for each scale. + kernel_size (tuple of int): Inner Conv2d kernel sizes. + stride (tuple of int): Inner Conv2d strides. + dilations (list of int): Inner Conv2d dilation on the time dimension. + win_length (int): Window size for each scale. + normalized (bool): Whether to normalize by magnitude after stft. + norm (str): Normalization method. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, + filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], + stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', + activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, + normalized=self.normalized, center=False, pad_mode=None, power=None) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, + dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm)) + in_chs = out_chs + out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm)) + self.conv_post = NormConv2d(out_chs, self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm) + + def forward(self, x: torch.Tensor): + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, 'b c w t -> b c t w') + for i, layer in enumerate(self.convs): + z = layer(z) + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap + + +class MultiScaleSTFTDiscriminator(MultiDiscriminator): + """Multi-Scale STFT (MS-STFT) discriminator. + + Args: + filters (int): Number of filters in convolutions. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + sep_channels (bool): Separate channels to distinct samples for stereo support. + n_ffts (Sequence[int]): Size of FFT for each scale. + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale. + win_lengths (Sequence[int]): Window size for each scale. + **kwargs: Additional args for STFTDiscriminator. + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False, + n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], + win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.sep_channels = sep_channels + self.discriminators = nn.ModuleList([ + DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, + n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) + for i in range(len(n_ffts)) + ]) + + @property + def num_discriminators(self): + return len(self.discriminators) + + def _separate_channels(self, x: torch.Tensor) -> torch.Tensor: + B, C, T = x.shape + return x.view(-1, 1, T) + + def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps diff --git a/audiocraft/adversarial/losses.py b/audiocraft/adversarial/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..be293e739bdc2d91273f30fb789befe7c8b49a43 --- /dev/null +++ b/audiocraft/adversarial/losses.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility module to handle adversarial losses without requiring to mess up the main training loop. +""" + +import typing as tp + +import flashy +import torch +import torch.nn as nn +import torch.nn.functional as F + + +ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2'] + + +AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]] +FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] + + +class AdversarialLoss(nn.Module): + """Adversary training wrapper. + + Args: + adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples. + We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]`` + where the first item is a list of logits and the second item is a list of feature maps. + optimizer (torch.optim.Optimizer): Optimizer used for training the given module. + loss (AdvLossType): Loss function for generator training. + loss_real (AdvLossType): Loss function for adversarial training on logits from real samples. + loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples. + loss_feat (FeatLossType): Feature matching loss function for generator training. + normalize (bool): Whether to normalize by number of sub-discriminators. + + Example of usage: + adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake) + for real in loader: + noise = torch.randn(...) + fake = model(noise) + adv_loss.train_adv(fake, real) + loss, _ = adv_loss(fake, real) + loss.backward() + """ + def __init__(self, + adversary: nn.Module, + optimizer: torch.optim.Optimizer, + loss: AdvLossType, + loss_real: AdvLossType, + loss_fake: AdvLossType, + loss_feat: tp.Optional[FeatLossType] = None, + normalize: bool = True): + super().__init__() + self.adversary: nn.Module = adversary + flashy.distrib.broadcast_model(self.adversary) + self.optimizer = optimizer + self.loss = loss + self.loss_real = loss_real + self.loss_fake = loss_fake + self.loss_feat = loss_feat + self.normalize = normalize + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # Add the optimizer state dict inside our own. + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + 'optimizer'] = self.optimizer.state_dict() + return destination + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # Load optimizer state. + self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer')) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def get_adversary_pred(self, x): + """Run adversary model, validating expected output format.""" + logits, fmaps = self.adversary(x) + assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \ + f'Expecting a list of tensors as logits but {type(logits)} found.' + assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.' + for fmap in fmaps: + assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \ + f'Expecting a list of tensors as feature maps but {type(fmap)} found.' + return logits, fmaps + + def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor: + """Train the adversary with the given fake and real example. + + We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]]. + The first item being the logits and second item being a list of feature maps for each sub-discriminator. + + This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`) + and call the optimizer. + """ + loss = torch.tensor(0., device=fake.device) + all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach()) + all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach()) + n_sub_adversaries = len(all_logits_fake_is_fake) + for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake): + loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake) + + if self.normalize: + loss /= n_sub_adversaries + + self.optimizer.zero_grad() + with flashy.distrib.eager_sync_model(self.adversary): + loss.backward() + self.optimizer.step() + + return loss + + def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Return the loss for the generator, i.e. trying to fool the adversary, + and feature matching loss if provided. + """ + adv = torch.tensor(0., device=fake.device) + feat = torch.tensor(0., device=fake.device) + with flashy.utils.readonly(self.adversary): + all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake) + all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real) + n_sub_adversaries = len(all_logits_fake_is_fake) + for logit_fake_is_fake in all_logits_fake_is_fake: + adv += self.loss(logit_fake_is_fake) + if self.loss_feat: + for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real): + feat += self.loss_feat(fmap_fake, fmap_real) + + if self.normalize: + adv /= n_sub_adversaries + feat /= n_sub_adversaries + + return adv, feat + + +def get_adv_criterion(loss_type: str) -> tp.Callable: + assert loss_type in ADVERSARIAL_LOSSES + if loss_type == 'mse': + return mse_loss + elif loss_type == 'hinge': + return hinge_loss + elif loss_type == 'hinge2': + return hinge2_loss + raise ValueError('Unsupported loss') + + +def get_fake_criterion(loss_type: str) -> tp.Callable: + assert loss_type in ADVERSARIAL_LOSSES + if loss_type == 'mse': + return mse_fake_loss + elif loss_type in ['hinge', 'hinge2']: + return hinge_fake_loss + raise ValueError('Unsupported loss') + + +def get_real_criterion(loss_type: str) -> tp.Callable: + assert loss_type in ADVERSARIAL_LOSSES + if loss_type == 'mse': + return mse_real_loss + elif loss_type in ['hinge', 'hinge2']: + return hinge_real_loss + raise ValueError('Unsupported loss') + + +def mse_real_loss(x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x)) + + +def mse_fake_loss(x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x)) + + +def hinge_real_loss(x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x))) + + +def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x))) + + +def mse_loss(x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return torch.tensor([0.0], device=x.device) + return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x)) + + +def hinge_loss(x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return torch.tensor([0.0], device=x.device) + return -x.mean() + + +def hinge2_loss(x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return torch.tensor([0.0]) + return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x))) + + +class FeatureMatchingLoss(nn.Module): + """Feature matching loss for adversarial training. + + Args: + loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1). + normalize (bool): Whether to normalize the loss. + by number of feature maps. + """ + def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True): + super().__init__() + self.loss = loss + self.normalize = normalize + + def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor: + assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0 + feat_loss = torch.tensor(0., device=fmap_fake[0].device) + feat_scale = torch.tensor(0., device=fmap_fake[0].device) + n_fmaps = 0 + for (feat_fake, feat_real) in zip(fmap_fake, fmap_real): + assert feat_fake.shape == feat_real.shape + n_fmaps += 1 + feat_loss += self.loss(feat_fake, feat_real) + feat_scale += torch.mean(torch.abs(feat_real)) + + if self.normalize: + feat_loss /= n_fmaps + + return feat_loss diff --git a/audiocraft/data/__init__.py b/audiocraft/data/__init__.py index 708a3dcead8dda89374a021177481dacae9f7fe9..2906ff12bc85a894837579f3137f6f71a0438329 100644 --- a/audiocraft/data/__init__.py +++ b/audiocraft/data/__init__.py @@ -3,6 +3,8 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Audio loading and writing support. Datasets for raw audio +or also including some metadata.""" # flake8: noqa -from . import audio, audio_dataset +from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset diff --git a/audiocraft/data/audio.py b/audiocraft/data/audio.py index 2048df6f175d7303bcf5c7b931922fd297908ead..a35dfd9c5a74330cfc3e1ba68932fffcb7c29750 100644 --- a/audiocraft/data/audio.py +++ b/audiocraft/data/audio.py @@ -18,11 +18,11 @@ import numpy as np import soundfile import torch from torch.nn import functional as F -import torchaudio as ta import av +import subprocess as sp -from .audio_utils import f32_pcm, i16_pcm, normalize_audio +from .audio_utils import f32_pcm, normalize_audio _av_initialized = False @@ -78,7 +78,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa seek_time (float): Time at which to start reading in the file. duration (float): Duration to read from the file. If set to -1, the whole file is read. Returns: - Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate + tuple of torch.Tensor, int: Tuple containing audio data and sample rate """ _init_av() with av.open(str(filepath)) as af: @@ -123,7 +123,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., duration (float): Duration to read from the file. If set to -1, the whole file is read. pad (bool): Pad output audio if not reaching expected duration. Returns: - Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate. + tuple of torch.Tensor, int: Tuple containing audio data and sample rate. """ fp = Path(filepath) if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg @@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., wav = torch.from_numpy(wav).t().contiguous() if len(wav.shape) == 1: wav = torch.unsqueeze(wav, 0) - elif ( - fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats() - and duration <= 0 and seek_time == 0 - ): - # Torchaudio is faster if we load an entire file at once. - wav, sr = ta.load(fp) else: wav, sr = _av_read(filepath, seek_time, duration) if pad and duration > 0: @@ -150,10 +144,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., return wav, sr +def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]): + # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely. + assert wav.dim() == 2, wav.shape + command = [ + 'ffmpeg', + '-loglevel', 'error', + '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]), + '-i', '-'] + flags + [str(out_path)] + input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes() + sp.run(command, input=input_, check=True) + + def audio_write(stem_name: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, - format: str = 'wav', mp3_rate: int = 320, normalize: bool = True, - strategy: str = 'peak', peak_clip_headroom_db: float = 1, + format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None, + normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, rms_headroom_db: float = 18, loudness_headroom_db: float = 14, loudness_compressor: bool = False, log_clipping: bool = True, make_parent_dir: bool = True, @@ -162,8 +168,11 @@ def audio_write(stem_name: tp.Union[str, Path], Args: stem_name (str or Path): Filename without extension which will be added automatically. - format (str): Either "wav" or "mp3". + wav (torch.Tensor): Audio data to save. + sample_rate (int): Sample rate of audio data. + format (str): Either "wav", "mp3", "ogg", or "flac". mp3_rate (int): kbps when using mp3s. + ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself. normalize (bool): if `True` (default), normalizes according to the prescribed strategy (see after). If `False`, the strategy is only used in case clipping would happen. @@ -175,7 +184,7 @@ def audio_write(stem_name: tp.Union[str, Path], than the `peak_clip` one to avoid further clipping. loudness_headroom_db (float): Target loudness for loudness normalization. loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. - when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still + when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still occurs despite strategy (only for 'rms'). make_parent_dir (bool): Make parent directory if it doesn't exist. Returns: @@ -188,16 +197,23 @@ def audio_write(stem_name: tp.Union[str, Path], raise ValueError("Input wav should be at most 2 dimension.") assert wav.isfinite().all() wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, - rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping, - sample_rate=sample_rate, stem_name=str(stem_name)) - kwargs: dict = {} + rms_headroom_db, loudness_headroom_db, loudness_compressor, + log_clipping=log_clipping, sample_rate=sample_rate, + stem_name=str(stem_name)) if format == 'mp3': suffix = '.mp3' - kwargs.update({"compression": mp3_rate}) + flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k'] elif format == 'wav': - wav = i16_pcm(wav) suffix = '.wav' - kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16}) + flags = ['-f', 'wav', '-c:a', 'pcm_s16le'] + elif format == 'ogg': + suffix = '.ogg' + flags = ['-f', 'ogg', '-c:a', 'libvorbis'] + if ogg_rate is not None: + flags += ['-b:a', f'{ogg_rate}k'] + elif format == 'flac': + suffix = '.flac' + flags = ['-f', 'flac'] else: raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.") if not add_suffix: @@ -206,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path], if make_parent_dir: path.parent.mkdir(exist_ok=True, parents=True) try: - ta.save(path, wav, sample_rate, **kwargs) + _piping_to_ffmpeg(path, wav, sample_rate, flags) except Exception: if path.exists(): # we do not want to leave half written files around. diff --git a/audiocraft/data/audio_dataset.py b/audiocraft/data/audio_dataset.py index cf21422ea0059cb2d6553f93e608b8f9fa0d3a50..9d7442526186b3712f5d4754f928a40ecd964174 100644 --- a/audiocraft/data/audio_dataset.py +++ b/audiocraft/data/audio_dataset.py @@ -3,12 +3,16 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - +"""AudioDataset support. In order to handle a larger number of files +without having to scan again the folders, we precompute some metadata +(filename, sample rate, duration), and use that to efficiently sample audio segments. +""" import argparse import copy from concurrent.futures import ThreadPoolExecutor, Future from dataclasses import dataclass, fields from contextlib import ExitStack +from functools import lru_cache import gzip import json import logging @@ -81,9 +85,12 @@ class AudioMeta(BaseInfo): class SegmentInfo(BaseInfo): meta: AudioMeta seek_time: float - n_frames: int # actual number of frames without padding + # The following values are given once the audio is processed, e.g. + # at the target sample rate and target number of channels. + n_frames: int # actual number of frames without padding total_frames: int # total number of frames, padding included - sample_rate: int # actual sample rate + sample_rate: int # actual sample rate + channels: int # number of audio channels. DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'] @@ -114,8 +121,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta: Args: m (AudioMeta): Audio meta to resolve. - fast (bool): If True, uses a really fast check for determining if a file is already absolute or not. - Only valid on Linux/Mac. + fast (bool): If True, uses a really fast check for determining if a file + is already absolute or not. Only valid on Linux/Mac. Returns: AudioMeta: Audio meta with resolved path. """ @@ -151,7 +158,7 @@ def find_audio_files(path: tp.Union[Path, str], progress (bool): Whether to log progress on audio files collection. workers (int): number of parallel workers, if 0, use only the current thread. Returns: - List[AudioMeta]: List of audio file path and its metadata. + list of AudioMeta: List of audio file path and its metadata. """ audio_files = [] futures: tp.List[Future] = [] @@ -203,7 +210,7 @@ def load_audio_meta(path: tp.Union[str, Path], resolve (bool): Whether to resolve the path from AudioMeta (default=True). fast (bool): activates some tricks to make things faster. Returns: - List[AudioMeta]: List of audio file path and its total duration. + list of AudioMeta: List of audio file path and its total duration. """ open_fn = gzip.open if str(path).lower().endswith('.gz') else open with open_fn(path, 'rb') as fp: # type: ignore @@ -250,9 +257,14 @@ class AudioDataset: allows to return a tuple containing the torch Tensor and additional metadata on the segment and the original audio meta. + Note that you can call `start_epoch(epoch)` in order to get + a deterministic "randomization" for `shuffle=True`. + For a given epoch and dataset index, this will always return the same extract. + You can get back some diversity by setting the `shuffle_seed` param. + Args: - meta (tp.List[AudioMeta]): List of audio files metadata. - segment_duration (float): Optional segment duration of audio to load. + meta (list of AudioMeta): List of audio files metadata. + segment_duration (float, optional): Optional segment duration of audio to load. If not specified, the dataset will load the full audio segment from the file. shuffle (bool): Set to `True` to have the data reshuffled at every epoch. sample_rate (int): Target sample rate of the loaded audio samples. @@ -266,10 +278,19 @@ class AudioDataset: is shorter than the desired segment. max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset. return_info (bool): Whether to return the wav only or return wav along with segment info and metadata. - min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided + min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided audio shorter than this will be filtered out. - max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided + max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided audio longer than this will be filtered out. + shuffle_seed (int): can be used to further randomize + load_wav (bool): if False, skip loading the wav but returns a tensor of 0 + with the expected segment_duration (which must be provided if load_wav is False). + permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration` + are False. Will ensure a permutation on files when going through the dataset. + In that case the epoch number must be provided in order for the model + to continue the permutation across epochs. In that case, it is assumed + that `num_samples = total_batch_size * num_updates_per_epoch`, with + `total_batch_size` the overall batch size accounting for all gpus. """ def __init__(self, meta: tp.List[AudioMeta], @@ -285,16 +306,14 @@ class AudioDataset: max_read_retry: int = 10, return_info: bool = False, min_audio_duration: tp.Optional[float] = None, - max_audio_duration: tp.Optional[float] = None + max_audio_duration: tp.Optional[float] = None, + shuffle_seed: int = 0, + load_wav: bool = True, + permutation_on_files: bool = False, ): - assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.' + assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta." assert segment_duration is None or segment_duration > 0 assert segment_duration is None or min_segment_ratio >= 0 - logging.debug(f'sample_on_duration: {sample_on_duration}') - logging.debug(f'sample_on_weight: {sample_on_weight}') - logging.debug(f'pad: {pad}') - logging.debug(f'min_segment_ratio: {min_segment_ratio}') - self.segment_duration = segment_duration self.min_segment_ratio = min_segment_ratio self.max_audio_duration = max_audio_duration @@ -317,13 +336,25 @@ class AudioDataset: self.sampling_probabilities = self._get_sampling_probabilities() self.max_read_retry = max_read_retry self.return_info = return_info + self.shuffle_seed = shuffle_seed + self.current_epoch: tp.Optional[int] = None + self.load_wav = load_wav + if not load_wav: + assert segment_duration is not None + self.permutation_on_files = permutation_on_files + if permutation_on_files: + assert not self.sample_on_duration + assert not self.sample_on_weight + assert self.shuffle + + def start_epoch(self, epoch: int): + self.current_epoch = epoch def __len__(self): return self.num_samples def _get_sampling_probabilities(self, normalized: bool = True): - """Return the sampling probabilities for each file inside `self.meta`. - """ + """Return the sampling probabilities for each file inside `self.meta`.""" scores: tp.List[float] = [] for file_meta in self.meta: score = 1. @@ -337,12 +368,32 @@ class AudioDataset: probabilities /= probabilities.sum() return probabilities - def sample_file(self, rng: torch.Generator) -> AudioMeta: - """Sample a given file from `self.meta`. Can be overriden in subclasses. + @staticmethod + @lru_cache(16) + def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int): + # Used to keep the most recent files permutation in memory implicitely. + # will work unless someone is using a lot of Datasets in parallel. + rng = torch.Generator() + rng.manual_seed(base_seed + permutation_index) + return torch.randperm(num_files, generator=rng) + + def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta: + """Sample a given file from `self.meta`. Can be overridden in subclasses. This is only called if `segment_duration` is not None. You must use the provided random number generator `rng` for reproducibility. + You can further make use of the index accessed. """ + if self.permutation_on_files: + assert self.current_epoch is not None + total_index = self.current_epoch * len(self) + index + permutation_index = total_index // len(self.meta) + relative_index = total_index % len(self.meta) + permutation = AudioDataset._get_file_permutation( + len(self.meta), permutation_index, self.shuffle_seed) + file_index = permutation[relative_index] + return self.meta[file_index] + if not self.sample_on_weight and not self.sample_on_duration: file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item()) else: @@ -350,6 +401,15 @@ class AudioDataset: return self.meta[file_index] + def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1): + # Override this method in subclass if needed. + if self.load_wav: + return audio_read(path, seek_time, duration, pad=False) + else: + assert self.segment_duration is not None + n_frames = int(self.sample_rate * self.segment_duration) + return torch.zeros(self.channels, n_frames), self.sample_rate + def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]: if self.segment_duration is None: file_meta = self.meta[index] @@ -357,18 +417,22 @@ class AudioDataset: out = convert_audio(out, sr, self.sample_rate, self.channels) n_frames = out.shape[-1] segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames, - sample_rate=self.sample_rate) + sample_rate=self.sample_rate, channels=out.shape[0]) else: rng = torch.Generator() if self.shuffle: - # We use index, plus extra randomness - rng.manual_seed(index + self.num_samples * random.randint(0, 2**24)) + # We use index, plus extra randomness, either totally random if we don't know the epoch. + # otherwise we make use of the epoch number and optional shuffle_seed. + if self.current_epoch is None: + rng.manual_seed(index + self.num_samples * random.randint(0, 2**24)) + else: + rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed)) else: # We only use index rng.manual_seed(index) for retry in range(self.max_read_retry): - file_meta = self.sample_file(rng) + file_meta = self.sample_file(index, rng) # We add some variance in the file position even if audio file is smaller than segment # without ending up with empty segments max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio) @@ -381,7 +445,7 @@ class AudioDataset: if self.pad: out = F.pad(out, (0, target_frames - n_frames)) segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames, - sample_rate=self.sample_rate) + sample_rate=self.sample_rate, channels=out.shape[0]) except Exception as exc: logger.warning("Error opening file %s: %r", file_meta.path, exc) if retry == self.max_read_retry - 1: @@ -423,7 +487,7 @@ class AudioDataset: if to_pad: # Each wav could be of a different duration as they are not segmented. for i in range(len(samples)): - # Determines the total legth of the signal with padding, so we update here as we pad. + # Determines the total length of the signal with padding, so we update here as we pad. segment_infos[i].total_frames = max_len wavs[i] = _pad_wav(wavs[i]) @@ -436,9 +500,7 @@ class AudioDataset: return torch.stack(samples) def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: - """Filters out audio files with short durations. - Removes from meta files that have durations that will not allow to samples examples from them. - """ + """Filters out audio files with audio durations that will not allow to sample examples from them.""" orig_len = len(meta) # Filter data that is too short. diff --git a/audiocraft/data/audio_utils.py b/audiocraft/data/audio_utils.py index 76d4bc2a33ce722d879db2af33cd1336bd6b1fb3..9d3129b84b114c5572078295604279884c79f2cc 100644 --- a/audiocraft/data/audio_utils.py +++ b/audiocraft/data/audio_utils.py @@ -3,7 +3,8 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - +"""Various utilities for audio convertion (pcm format, sample rate and channels), +and volume normalization.""" import sys import typing as tp @@ -47,8 +48,7 @@ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor def convert_audio(wav: torch.Tensor, from_rate: float, to_rate: float, to_channels: int) -> torch.Tensor: - """Convert audio to new sample rate and number of audio channels. - """ + """Convert audio to new sample rate and number of audio channels.""" wav = julius.resample_frac(wav, int(from_rate), int(to_rate)) wav = convert_audio_channels(wav, to_channels) return wav @@ -66,7 +66,7 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db loudness_compressor (bool): Uses tanh for soft clipping. energy_floor (float): anything below that RMS level will not be rescaled. Returns: - output (torch.Tensor): Loudness normalized output data. + torch.Tensor: Loudness normalized output data. """ energy = wav.pow(2).mean().sqrt().item() if energy < energy_floor: @@ -117,7 +117,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True, log_clipping (bool): If True, basic logging on stderr when clipping still occurs despite strategy (only for 'rms'). sample_rate (int): Sample rate for the audio data (required for loudness). - stem_name (Optional[str]): Stem name for clipping logging. + stem_name (str, optional): Stem name for clipping logging. Returns: torch.Tensor: Normalized audio. """ @@ -150,17 +150,19 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor: """ if wav.dtype.is_floating_point: return wav - else: - assert wav.dtype == torch.int16 + elif wav.dtype == torch.int16: return wav.float() / 2**15 + elif wav.dtype == torch.int32: + return wav.float() / 2**31 + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") def i16_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to int 16 bits PCM format. - ..Warning:: There exist many formula for doing this convertion. None are perfect - due to the asymetry of the int16 range. One either have possible clipping, DC offset, - or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom, + ..Warning:: There exist many formula for doing this conversion. None are perfect + due to the asymmetry of the int16 range. One either have possible clipping, DC offset, + or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, it is possible that `i16_pcm(f32_pcm)) != Identity`. """ if wav.dtype.is_floating_point: diff --git a/audiocraft/data/info_audio_dataset.py b/audiocraft/data/info_audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..47ab4b1594faf1e9f1ce962fb980d80295b1f079 --- /dev/null +++ b/audiocraft/data/info_audio_dataset.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Base classes for the datasets that also provide non-audio metadata, +e.g. description, text transcription etc. +""" +from dataclasses import dataclass +import logging +import math +import re +import typing as tp + +import torch + +from .audio_dataset import AudioDataset, AudioMeta +from ..environment import AudioCraftEnvironment +from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes + + +logger = logging.getLogger(__name__) + + +def _clusterify_meta(meta: AudioMeta) -> AudioMeta: + """Monkey-patch meta to match cluster specificities.""" + meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path) + if meta.info_path is not None: + meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path) + return meta + + +def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: + """Monkey-patch all meta to match cluster specificities.""" + return [_clusterify_meta(m) for m in meta] + + +@dataclass +class AudioInfo(SegmentWithAttributes): + """Dummy SegmentInfo with empty attributes. + + The InfoAudioDataset is expected to return metadata that inherits + from SegmentWithAttributes class and can return conditioning attributes. + + This basically guarantees all datasets will be compatible with current + solver that contain conditioners requiring this. + """ + audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM. + + def to_condition_attributes(self) -> ConditioningAttributes: + return ConditioningAttributes() + + +class InfoAudioDataset(AudioDataset): + """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform. + + See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments. + """ + def __init__(self, meta: tp.List[AudioMeta], **kwargs): + super().__init__(clusterify_all_meta(meta), **kwargs) + + def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]: + if not self.return_info: + wav = super().__getitem__(index) + assert isinstance(wav, torch.Tensor) + return wav + wav, meta = super().__getitem__(index) + return wav, AudioInfo(**meta.to_dict()) + + +def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]: + """Preprocess a single keyword or possible a list of keywords.""" + if isinstance(value, list): + return get_keyword_list(value) + else: + return get_keyword(value) + + +def get_string(value: tp.Optional[str]) -> tp.Optional[str]: + """Preprocess a single keyword.""" + if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': + return None + else: + return value.strip() + + +def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]: + """Preprocess a single keyword.""" + if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': + return None + else: + return value.strip().lower() + + +def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]: + """Preprocess a list of keywords.""" + if isinstance(values, str): + values = [v.strip() for v in re.split(r'[,\s]', values)] + elif isinstance(values, float) and math.isnan(values): + values = [] + if not isinstance(values, list): + logger.debug(f"Unexpected keyword list {values}") + values = [str(values)] + + kws = [get_keyword(v) for v in values] + kw_list = [k for k in kws if k is not None] + if len(kw_list) == 0: + return None + else: + return kw_list diff --git a/audiocraft/data/music_dataset.py b/audiocraft/data/music_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4e28796939f9cde2b23a2c4bf43fd7ba5fa26b2d --- /dev/null +++ b/audiocraft/data/music_dataset.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Dataset of music tracks with rich metadata. +""" +from dataclasses import dataclass, field, fields, replace +import gzip +import json +import logging +from pathlib import Path +import random +import typing as tp + +import torch + +from .info_audio_dataset import ( + InfoAudioDataset, + AudioInfo, + get_keyword_list, + get_keyword, + get_string +) +from ..modules.conditioners import ( + ConditioningAttributes, + JointEmbedCondition, + WavCondition, +) +from ..utils.utils import warn_once + + +logger = logging.getLogger(__name__) + + +@dataclass +class MusicInfo(AudioInfo): + """Segment info augmented with music metadata. + """ + # music-specific metadata + title: tp.Optional[str] = None + artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits + key: tp.Optional[str] = None + bpm: tp.Optional[float] = None + genre: tp.Optional[str] = None + moods: tp.Optional[list] = None + keywords: tp.Optional[list] = None + description: tp.Optional[str] = None + name: tp.Optional[str] = None + instrument: tp.Optional[str] = None + # original wav accompanying the metadata + self_wav: tp.Optional[WavCondition] = None + # dict mapping attributes names to tuple of wav, text and metadata + joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) + + @property + def has_music_meta(self) -> bool: + return self.name is not None + + def to_condition_attributes(self) -> ConditioningAttributes: + out = ConditioningAttributes() + for _field in fields(self): + key, value = _field.name, getattr(self, _field.name) + if key == 'self_wav': + out.wav[key] = value + elif key == 'joint_embed': + for embed_attribute, embed_cond in value.items(): + out.joint_embed[embed_attribute] = embed_cond + else: + if isinstance(value, list): + value = ' '.join(value) + out.text[key] = value + return out + + @staticmethod + def attribute_getter(attribute): + if attribute == 'bpm': + preprocess_func = get_bpm + elif attribute == 'key': + preprocess_func = get_musical_key + elif attribute in ['moods', 'keywords']: + preprocess_func = get_keyword_list + elif attribute in ['genre', 'name', 'instrument']: + preprocess_func = get_keyword + elif attribute in ['title', 'artist', 'description']: + preprocess_func = get_string + else: + preprocess_func = None + return preprocess_func + + @classmethod + def from_dict(cls, dictionary: dict, fields_required: bool = False): + _dictionary: tp.Dict[str, tp.Any] = {} + + # allow a subset of attributes to not be loaded from the dictionary + # these attributes may be populated later + post_init_attributes = ['self_wav', 'joint_embed'] + optional_fields = ['keywords'] + + for _field in fields(cls): + if _field.name in post_init_attributes: + continue + elif _field.name not in dictionary: + if fields_required and _field.name not in optional_fields: + raise KeyError(f"Unexpected missing key: {_field.name}") + else: + preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name) + value = dictionary[_field.name] + if preprocess_func: + value = preprocess_func(value) + _dictionary[_field.name] = value + return cls(**_dictionary) + + +def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0., + drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo: + """Augment MusicInfo description with additional metadata fields and potential dropout. + Additional textual attributes are added given probability 'merge_text_conditions_p' and + the original textual description is dropped from the augmented description given probability drop_desc_p. + + Args: + music_info (MusicInfo): The music metadata to augment. + merge_text_p (float): Probability of merging additional metadata to the description. + If provided value is 0, then no merging is performed. + drop_desc_p (float): Probability of dropping the original description on text merge. + if provided value is 0, then no drop out is performed. + drop_other_p (float): Probability of dropping the other fields used for text augmentation. + Returns: + MusicInfo: The MusicInfo with augmented textual description. + """ + def is_valid_field(field_name: str, field_value: tp.Any) -> bool: + valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords'] + valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list)) + keep_field = random.uniform(0, 1) < drop_other_p + return valid_field_name and valid_field_value and keep_field + + def process_value(v: tp.Any) -> str: + if isinstance(v, (int, float, str)): + return str(v) + if isinstance(v, list): + return ", ".join(v) + else: + raise ValueError(f"Unknown type for text value! ({type(v), v})") + + description = music_info.description + + metadata_text = "" + if random.uniform(0, 1) < merge_text_p: + meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}' + for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))] + random.shuffle(meta_pairs) + metadata_text = ". ".join(meta_pairs) + description = description if not random.uniform(0, 1) < drop_desc_p else None + logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}") + + if description is None: + description = metadata_text if len(metadata_text) > 1 else None + else: + description = ". ".join([description.rstrip('.'), metadata_text]) + description = description.strip() if description else None + + music_info = replace(music_info) + music_info.description = description + return music_info + + +class Paraphraser: + def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.): + self.paraphrase_p = paraphrase_p + open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open + with open_fn(paraphrase_source, 'rb') as f: # type: ignore + self.paraphrase_source = json.loads(f.read()) + logger.info(f"loaded paraphrasing source from: {paraphrase_source}") + + def sample_paraphrase(self, audio_path: str, description: str): + if random.random() >= self.paraphrase_p: + return description + info_path = Path(audio_path).with_suffix('.json') + if info_path not in self.paraphrase_source: + warn_once(logger, f"{info_path} not in paraphrase source!") + return description + new_desc = random.choice(self.paraphrase_source[info_path]) + logger.debug(f"{description} -> {new_desc}") + return new_desc + + +class MusicDataset(InfoAudioDataset): + """Music dataset is an AudioDataset with music-related metadata. + + Args: + info_fields_required (bool): Whether to enforce having required fields. + merge_text_p (float): Probability of merging additional metadata to the description. + drop_desc_p (float): Probability of dropping the original description on text merge. + drop_other_p (float): Probability of dropping the other fields used for text augmentation. + joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned. + paraphrase_source (str, optional): Path to the .json or .json.gz file containing the + paraphrases for the description. The json should be a dict with keys are the + original info path (e.g. track_path.json) and each value is a list of possible + paraphrased. + paraphrase_p (float): probability of taking a paraphrase. + + See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments. + """ + def __init__(self, *args, info_fields_required: bool = True, + merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0., + joint_embed_attributes: tp.List[str] = [], + paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0, + **kwargs): + kwargs['return_info'] = True # We require the info for each song of the dataset. + super().__init__(*args, **kwargs) + self.info_fields_required = info_fields_required + self.merge_text_p = merge_text_p + self.drop_desc_p = drop_desc_p + self.drop_other_p = drop_other_p + self.joint_embed_attributes = joint_embed_attributes + self.paraphraser = None + if paraphrase_source is not None: + self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p) + + def __getitem__(self, index): + wav, info = super().__getitem__(index) + info_data = info.to_dict() + music_info_path = Path(info.meta.path).with_suffix('.json') + + if Path(music_info_path).exists(): + with open(music_info_path, 'r') as json_file: + music_data = json.load(json_file) + music_data.update(info_data) + music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required) + if self.paraphraser is not None: + music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description) + if self.merge_text_p: + music_info = augment_music_info_description( + music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p) + else: + music_info = MusicInfo.from_dict(info_data, fields_required=False) + + music_info.self_wav = WavCondition( + wav=wav[None], length=torch.tensor([info.n_frames]), + sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) + + for att in self.joint_embed_attributes: + att_value = getattr(music_info, att) + joint_embed_cond = JointEmbedCondition( + wav[None], [att_value], torch.tensor([info.n_frames]), + sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) + music_info.joint_embed[att] = joint_embed_cond + + return wav, music_info + + +def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]: + """Preprocess key keywords, discarding them if there are multiple key defined.""" + if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': + return None + elif ',' in value: + # For now, we discard when multiple keys are defined separated with comas + return None + else: + return value.strip().lower() + + +def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]: + """Preprocess to a float.""" + if value is None: + return None + try: + return float(value) + except ValueError: + return None diff --git a/audiocraft/data/sound_dataset.py b/audiocraft/data/sound_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b88cbe8016b4bd28c2de749177c9af29f7755fc --- /dev/null +++ b/audiocraft/data/sound_dataset.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Dataset of audio with a simple description. +""" + +from dataclasses import dataclass, fields, replace +import json +from pathlib import Path +import random +import typing as tp + +import numpy as np +import torch + +from .info_audio_dataset import ( + InfoAudioDataset, + get_keyword_or_keyword_list +) +from ..modules.conditioners import ( + ConditioningAttributes, + SegmentWithAttributes, + WavCondition, +) + + +EPS = torch.finfo(torch.float32).eps +TARGET_LEVEL_LOWER = -35 +TARGET_LEVEL_UPPER = -15 + + +@dataclass +class SoundInfo(SegmentWithAttributes): + """Segment info augmented with Sound metadata. + """ + description: tp.Optional[str] = None + self_wav: tp.Optional[torch.Tensor] = None + + @property + def has_sound_meta(self) -> bool: + return self.description is not None + + def to_condition_attributes(self) -> ConditioningAttributes: + out = ConditioningAttributes() + + for _field in fields(self): + key, value = _field.name, getattr(self, _field.name) + if key == 'self_wav': + out.wav[key] = value + else: + out.text[key] = value + return out + + @staticmethod + def attribute_getter(attribute): + if attribute == 'description': + preprocess_func = get_keyword_or_keyword_list + else: + preprocess_func = None + return preprocess_func + + @classmethod + def from_dict(cls, dictionary: dict, fields_required: bool = False): + _dictionary: tp.Dict[str, tp.Any] = {} + + # allow a subset of attributes to not be loaded from the dictionary + # these attributes may be populated later + post_init_attributes = ['self_wav'] + + for _field in fields(cls): + if _field.name in post_init_attributes: + continue + elif _field.name not in dictionary: + if fields_required: + raise KeyError(f"Unexpected missing key: {_field.name}") + else: + preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name) + value = dictionary[_field.name] + if preprocess_func: + value = preprocess_func(value) + _dictionary[_field.name] = value + return cls(**_dictionary) + + +class SoundDataset(InfoAudioDataset): + """Sound audio dataset: Audio dataset with environmental sound-specific metadata. + + Args: + info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata. + external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset. + The metadata files contained in this folder are expected to match the stem of the audio file with + a json extension. + aug_p (float): Probability of performing audio mixing augmentation on the batch. + mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation. + mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation. + mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation. + mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation. + kwargs: Additional arguments for AudioDataset. + + See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments. + """ + def __init__( + self, + *args, + info_fields_required: bool = True, + external_metadata_source: tp.Optional[str] = None, + aug_p: float = 0., + mix_p: float = 0., + mix_snr_low: int = -5, + mix_snr_high: int = 5, + mix_min_overlap: float = 0.5, + **kwargs + ): + kwargs['return_info'] = True # We require the info for each song of the dataset. + super().__init__(*args, **kwargs) + self.info_fields_required = info_fields_required + self.external_metadata_source = external_metadata_source + self.aug_p = aug_p + self.mix_p = mix_p + if self.aug_p > 0: + assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0" + assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio" + self.mix_snr_low = mix_snr_low + self.mix_snr_high = mix_snr_high + self.mix_min_overlap = mix_min_overlap + + def _get_info_path(self, path: tp.Union[str, Path]) -> Path: + """Get path of JSON with metadata (description, etc.). + If there exists a JSON with the same name as 'path.name', then it will be used. + Else, such JSON will be searched for in an external json source folder if it exists. + """ + info_path = Path(path).with_suffix('.json') + if Path(info_path).exists(): + return info_path + elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists(): + return Path(self.external_metadata_source) / info_path.name + else: + raise Exception(f"Unable to find a metadata JSON for path: {path}") + + def __getitem__(self, index): + wav, info = super().__getitem__(index) + info_data = info.to_dict() + info_path = self._get_info_path(info.meta.path) + if Path(info_path).exists(): + with open(info_path, 'r') as json_file: + sound_data = json.load(json_file) + sound_data.update(info_data) + sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required) + # if there are multiple descriptions, sample one randomly + if isinstance(sound_info.description, list): + sound_info.description = random.choice(sound_info.description) + else: + sound_info = SoundInfo.from_dict(info_data, fields_required=False) + + sound_info.self_wav = WavCondition( + wav=wav[None], length=torch.tensor([info.n_frames]), + sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) + + return wav, sound_info + + def collater(self, samples): + # when training, audio mixing is performed in the collate function + wav, sound_info = super().collater(samples) # SoundDataset always returns infos + if self.aug_p > 0: + wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p, + snr_low=self.mix_snr_low, snr_high=self.mix_snr_high, + min_overlap=self.mix_min_overlap) + return wav, sound_info + + +def rms_f(x: torch.Tensor) -> torch.Tensor: + return (x ** 2).mean(1).pow(0.5) + + +def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor: + """Normalize the signal to the target level.""" + rms = rms_f(audio) + scalar = 10 ** (target_level / 20) / (rms + EPS) + audio = audio * scalar.unsqueeze(1) + return audio + + +def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor: + return (abs(audio) > clipping_threshold).any(1) + + +def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor: + start = random.randint(0, int(src.shape[1] * (1 - min_overlap))) + remainder = src.shape[1] - start + if dst.shape[1] > remainder: + src[:, start:] = src[:, start:] + dst[:, :remainder] + else: + src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst + return src + + +def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float, + target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor: + """Function to mix clean speech and noise at various SNR levels. + + Args: + clean (torch.Tensor): Clean audio source to mix, of shape [B, T]. + noise (torch.Tensor): Noise audio source to mix, of shape [B, T]. + snr (int): SNR level when mixing. + min_overlap (float): Minimum overlap between the two mixed sources. + target_level (int): Gain level in dB. + clipping_threshold (float): Threshold for clipping the audio. + Returns: + torch.Tensor: The mixed audio, of shape [B, T]. + """ + if clean.shape[1] > noise.shape[1]: + noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1])) + else: + noise = noise[:, :clean.shape[1]] + + # normalizing to -25 dB FS + clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS) + clean = normalize(clean, target_level) + rmsclean = rms_f(clean) + + noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS) + noise = normalize(noise, target_level) + rmsnoise = rms_f(noise) + + # set the noise level for a given SNR + noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1) + noisenewlevel = noise * noisescalar + + # mix noise and clean speech + noisyspeech = mix_pair(clean, noisenewlevel, min_overlap) + + # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value + # there is a chance of clipping that might happen with very less probability, which is not a major issue. + noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER) + rmsnoisy = rms_f(noisyspeech) + scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1) + noisyspeech = noisyspeech * scalarnoisy + clean = clean * scalarnoisy + noisenewlevel = noisenewlevel * scalarnoisy + + # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly + clipped = is_clipped(noisyspeech) + if clipped.any(): + noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS) + noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel + + return noisyspeech + + +def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float): + if snr_low == snr_high: + snr = snr_low + else: + snr = np.random.randint(snr_low, snr_high) + mix = snr_mixer(src, dst, snr, min_overlap) + return mix + + +def mix_text(src_text: str, dst_text: str): + """Mix text from different sources by concatenating them.""" + if src_text == dst_text: + return src_text + return src_text + " " + dst_text + + +def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float, + snr_low: int, snr_high: int, min_overlap: float): + """Mix samples within a batch, summing the waveforms and concatenating the text infos. + + Args: + wavs (torch.Tensor): Audio tensors of shape [B, C, T]. + infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio. + aug_p (float): Augmentation probability. + mix_p (float): Proportion of items in the batch to mix (and merge) together. + snr_low (int): Lowerbound for sampling SNR. + snr_high (int): Upperbound for sampling SNR. + min_overlap (float): Minimum overlap between mixed samples. + Returns: + tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs + and mixed SoundInfo for the given batch. + """ + # no mixing to perform within the batch + if mix_p == 0: + return wavs, infos + + if random.uniform(0, 1) < aug_p: + # perform all augmentations on waveforms as [B, T] + # randomly picking pairs of audio to mix + assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}" + wavs = wavs.mean(dim=1, keepdim=False) + B, T = wavs.shape + k = int(mix_p * B) + mixed_sources_idx = torch.randperm(B)[:k] + mixed_targets_idx = torch.randperm(B)[:k] + aug_wavs = snr_mix( + wavs[mixed_sources_idx], + wavs[mixed_targets_idx], + snr_low, + snr_high, + min_overlap, + ) + # mixing textual descriptions in metadata + descriptions = [info.description for info in infos] + aug_infos = [] + for i, j in zip(mixed_sources_idx, mixed_targets_idx): + text = mix_text(descriptions[i], descriptions[j]) + m = replace(infos[i]) + m.description = text + aug_infos.append(m) + + # back to [B, C, T] + aug_wavs = aug_wavs.unsqueeze(1) + assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch." + assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}" + assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch" + + return aug_wavs, aug_infos # [B, C, T] + else: + # randomly pick samples in the batch to match + # the batch size when performing audio mixing + B, C, T = wavs.shape + k = int(mix_p * B) + wav_idx = torch.randperm(B)[:k] + wavs = wavs[wav_idx] + infos = [infos[i] for i in wav_idx] + assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch" + + return wavs, infos # [B, C, T] diff --git a/audiocraft/data/zip.py b/audiocraft/data/zip.py index 1f1154231da321dd38d151ff285dbcff5e38a6e0..f0b17849d36991e7def35a14d3d518b9d867ce36 100644 --- a/audiocraft/data/zip.py +++ b/audiocraft/data/zip.py @@ -3,6 +3,8 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Utility for reading some info from inside a zip file. +""" import typing import zipfile @@ -18,13 +20,13 @@ MODE = Literal['r', 'w', 'x', 'a'] @dataclass(order=True) class PathInZip: - """Class for holding a path of file within a zip file. + """Hold a path of file within a zip file. Args: - path: The convention is : + path (str): The convention is :. Let's assume there is a zip file /some/location/foo.zip and inside of it is a json file located at /data/file1.json, - Then we expect path = "/some/location/foo.zip:/data/file1.json" + Then we expect path = "/some/location/foo.zip:/data/file1.json". """ INFO_PATH_SEP = ':' @@ -55,7 +57,7 @@ def set_zip_cache_size(max_size: int): """Sets the maximal LRU caching for zip file opening. Args: - max_size: the maximal LRU cache. + max_size (int): the maximal LRU cache. """ global _cached_open_zip _cached_open_zip = lru_cache(max_size)(_open_zip) @@ -65,8 +67,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: """Opens a file stored inside a zip and returns a file-like object. Args: - path_in_zip: A PathInZip object representing the file to return a file-like object of. - mode: The mode in which to open the file with. + path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of. + mode (str): The mode in which to open the file with. Returns: A file-like object for PathInZip. """ diff --git a/audiocraft/environment.py b/audiocraft/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..adc7819305758bb50a9984928bfa7f13eabef5f5 --- /dev/null +++ b/audiocraft/environment.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Provides cluster and tools configuration across clusters (slurm, dora, utilities). +""" + +import logging +import os +from pathlib import Path +import re +import typing as tp + +import omegaconf + +from .utils.cluster import _guess_cluster_type + + +logger = logging.getLogger(__name__) + + +class AudioCraftEnvironment: + """Environment configuration for teams and clusters. + + AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment + or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment + provides pointers to a reference folder resolved automatically across clusters that is shared across team members, + allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically + map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters. + + The cluster type is identified automatically and base configuration file is read from config/teams.yaml. + Use the following environment variables to specify the cluster, team or configuration: + + AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type + cannot be inferred automatically. + AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration. + If not set, configuration is read from config/teams.yaml. + AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team. + Cluster configuration are shared across teams to match compute allocation, + specify your cluster configuration in the configuration file under a key mapping + your team name. + """ + _instance = None + DEFAULT_TEAM = "default" + + def __init__(self) -> None: + """Loads configuration.""" + self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM) + cluster_type = _guess_cluster_type() + cluster = os.getenv( + "AUDIOCRAFT_CLUSTER", cluster_type.value + ) + logger.info("Detecting cluster type %s", cluster_type) + + self.cluster: str = cluster + + config_path = os.getenv( + "AUDIOCRAFT_CONFIG", + Path(__file__) + .parent.parent.joinpath("config/teams", self.team) + .with_suffix(".yaml"), + ) + self.config = omegaconf.OmegaConf.load(config_path) + self._dataset_mappers = [] + cluster_config = self._get_cluster_config() + if "dataset_mappers" in cluster_config: + for pattern, repl in cluster_config["dataset_mappers"].items(): + regex = re.compile(pattern) + self._dataset_mappers.append((regex, repl)) + + def _get_cluster_config(self) -> omegaconf.DictConfig: + assert isinstance(self.config, omegaconf.DictConfig) + return self.config[self.cluster] + + @classmethod + def instance(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls): + """Clears the environment and forces a reload on next invocation.""" + cls._instance = None + + @classmethod + def get_team(cls) -> str: + """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var. + If not defined, defaults to "labs". + """ + return cls.instance().team + + @classmethod + def get_cluster(cls) -> str: + """Gets the detected cluster. + This value can be overridden by the AUDIOCRAFT_CLUSTER env var. + """ + return cls.instance().cluster + + @classmethod + def get_dora_dir(cls) -> Path: + """Gets the path to the dora directory for the current team and cluster. + Value is overridden by the AUDIOCRAFT_DORA_DIR env var. + """ + cluster_config = cls.instance()._get_cluster_config() + dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"]) + logger.warning(f"Dora directory: {dora_dir}") + return Path(dora_dir) + + @classmethod + def get_reference_dir(cls) -> Path: + """Gets the path to the reference directory for the current team and cluster. + Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var. + """ + cluster_config = cls.instance()._get_cluster_config() + return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"])) + + @classmethod + def get_slurm_exclude(cls) -> tp.Optional[str]: + """Get the list of nodes to exclude for that cluster.""" + cluster_config = cls.instance()._get_cluster_config() + return cluster_config.get("slurm_exclude") + + @classmethod + def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str: + """Gets the requested partitions for the current team and cluster as a comma-separated string. + + Args: + partition_types (list[str], optional): partition types to retrieve. Values must be + from ['global', 'team']. If not provided, the global partition is returned. + """ + if not partition_types: + partition_types = ["global"] + + cluster_config = cls.instance()._get_cluster_config() + partitions = [ + cluster_config["partitions"][partition_type] + for partition_type in partition_types + ] + return ",".join(partitions) + + @classmethod + def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path: + """Converts reference placeholder in path with configured reference dir to resolve paths. + + Args: + path (str or Path): Path to resolve. + Returns: + Path: Resolved path. + """ + path = str(path) + + if path.startswith("//reference"): + reference_dir = cls.get_reference_dir() + logger.warn(f"Reference directory: {reference_dir}") + assert ( + reference_dir.exists() and reference_dir.is_dir() + ), f"Reference directory does not exist: {reference_dir}." + path = re.sub("^//reference", str(reference_dir), path) + + return Path(path) + + @classmethod + def apply_dataset_mappers(cls, path: str) -> str: + """Applies dataset mapping regex rules as defined in the configuration. + If no rules are defined, the path is returned as-is. + """ + instance = cls.instance() + + for pattern, repl in instance._dataset_mappers: + path = pattern.sub(repl, path) + + return path diff --git a/audiocraft/grids/__init__.py b/audiocraft/grids/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70643517cd1a8b4e712eca90e23411ae89937795 --- /dev/null +++ b/audiocraft/grids/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Dora Grids.""" diff --git a/audiocraft/grids/_base_explorers.py b/audiocraft/grids/_base_explorers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f26666aa596f7bd2e8695c4f00e7963e978ceb --- /dev/null +++ b/audiocraft/grids/_base_explorers.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +import time +import typing as tp +from dora import Explorer +import treetable as tt + + +def get_sheep_ping(sheep) -> tp.Optional[str]: + """Return the amount of time since the Sheep made some update + to its log. Returns a str using the relevant time unit.""" + ping = None + if sheep.log is not None and sheep.log.exists(): + delta = time.time() - sheep.log.stat().st_mtime + if delta > 3600 * 24: + ping = f'{delta / (3600 * 24):.1f}d' + elif delta > 3600: + ping = f'{delta / (3600):.1f}h' + elif delta > 60: + ping = f'{delta / 60:.1f}m' + else: + ping = f'{delta:.1f}s' + return ping + + +class BaseExplorer(ABC, Explorer): + """Base explorer for AudioCraft grids. + + All task specific solvers are expected to implement the `get_grid_metrics` + method to specify logic about metrics to display for a given task. + + If additional stages are used, the child explorer must define how to handle + these new stages in the `process_history` and `process_sheep` methods. + """ + def stages(self): + return ["train", "valid", "evaluate"] + + def get_grid_meta(self): + """Returns the list of Meta information to display for each XP/job. + """ + return [ + tt.leaf("index", align=">"), + tt.leaf("name", wrap=140), + tt.leaf("state"), + tt.leaf("sig", align=">"), + tt.leaf("sid", align="<"), + ] + + @abstractmethod + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table. + """ + ... + + def process_sheep(self, sheep, history): + train = { + "epoch": len(history), + } + parts = {"train": train} + for metrics in history: + for key, sub in metrics.items(): + part = parts.get(key, {}) + if 'duration' in sub: + # Convert to minutes for readability. + sub['duration'] = sub['duration'] / 60. + part.update(sub) + parts[key] = part + ping = get_sheep_ping(sheep) + if ping is not None: + for name in self.stages(): + if name not in parts: + parts[name] = {} + # Add the ping to each part for convenience. + parts[name]['ping'] = ping + return parts diff --git a/audiocraft/grids/audiogen/__init__.py b/audiocraft/grids/audiogen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0a2688450ce120088b79c3314a2f267394dc11 --- /dev/null +++ b/audiocraft/grids/audiogen/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""AudioGen grids.""" diff --git a/audiocraft/grids/audiogen/audiogen_base_16khz.py b/audiocraft/grids/audiogen/audiogen_base_16khz.py new file mode 100644 index 0000000000000000000000000000000000000000..190cc1d0a1e316347e8ebbdfc8de7e2942c1b3d7 --- /dev/null +++ b/audiocraft/grids/audiogen/audiogen_base_16khz.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ..musicgen._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=64, partition=partitions) + launcher.bind_(solver='audiogen/audiogen_base_16khz') + # replace this by the desired environmental sound dataset + launcher.bind_(dset='internal/sounds_16khz') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + + launcher.bind_(fsdp) + launcher(medium) diff --git a/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py b/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..12f6d402a3c4a113d4c37be062790fa435b72104 --- /dev/null +++ b/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluation with objective metrics for the pretrained AudioGen models. +This grid takes signature from the training grid and runs evaluation-only stage. + +When running the grid for the first time, please use: +REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval +and re-use the REGEN=1 option when the grid is changed to force regenerating it. + +Note that you need the proper metrics external libraries setup to use all +the objective metrics activated in this grid. Refer to the README for more information. +""" + +import os + +from ..musicgen._explorers import GenerationEvalExplorer +from ...environment import AudioCraftEnvironment +from ... import train + + +def eval(launcher, batch_size: int = 32): + opts = { + 'dset': 'audio/audiocaps_16khz', + 'solver/audiogen/evaluation': 'objective_eval', + 'execute_only': 'evaluate', + '+dataset.evaluate.batch_size': batch_size, + '+metrics.fad.tf.batch_size': 32, + } + # binary for FAD computation: replace this path with your own path + metrics_opts = { + 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' + } + opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.} + opt2 = {'transformer_lm.two_step_cfg': True} + + sub = launcher.bind(opts) + sub.bind_(metrics_opts) + + # base objective metrics + sub(opt1, opt2) + + +@GenerationEvalExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=4, partition=partitions) + + if 'REGEN' not in os.environ: + folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] + with launcher.job_array(): + for sig in folder.iterdir(): + if not sig.is_symlink(): + continue + xp = train.main.get_xp_from_sig(sig.name) + launcher(xp.argv) + return + + audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz") + audiogen_base.bind_({'autocast': False, 'fsdp.use': True}) + + audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'}) + audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'}) + eval(audiogen_base_medium, batch_size=128) diff --git a/audiocraft/grids/compression/__init__.py b/audiocraft/grids/compression/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b688528f1f3e4efc0c2a1e9d490f33c4158b3f0 --- /dev/null +++ b/audiocraft/grids/compression/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""EnCodec grids.""" diff --git a/audiocraft/grids/compression/_explorers.py b/audiocraft/grids/compression/_explorers.py new file mode 100644 index 0000000000000000000000000000000000000000..eed30d5b8a1c14676503148ddf133c79ed2e33bf --- /dev/null +++ b/audiocraft/grids/compression/_explorers.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import treetable as tt + +from .._base_explorers import BaseExplorer + + +class CompressionExplorer(BaseExplorer): + eval_metrics = ["sisnr", "visqol"] + + def stages(self): + return ["train", "valid", "evaluate"] + + def get_grid_meta(self): + """Returns the list of Meta information to display for each XP/job. + """ + return [ + tt.leaf("index", align=">"), + tt.leaf("name", wrap=140), + tt.leaf("state"), + tt.leaf("sig", align=">"), + ] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table. + """ + return [ + tt.group( + "train", + [ + tt.leaf("epoch"), + tt.leaf("bandwidth", ".2f"), + tt.leaf("adv", ".4f"), + tt.leaf("d_loss", ".4f"), + ], + align=">", + ), + tt.group( + "valid", + [ + tt.leaf("bandwidth", ".2f"), + tt.leaf("adv", ".4f"), + tt.leaf("msspec", ".4f"), + tt.leaf("sisnr", ".2f"), + ], + align=">", + ), + tt.group( + "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">" + ), + ] diff --git a/audiocraft/grids/compression/debug.py b/audiocraft/grids/compression/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..5612ff5688d85fede0e605b244919e8081cb1da9 --- /dev/null +++ b/audiocraft/grids/compression/debug.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Grid search file, simply list all the exp you want in `explorer`. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line. + +This grid is a minimal example for debugging compression task +and how to override parameters directly in a grid. +Learn more about dora grids: https://github.com/facebookresearch/dora +""" + +from ._explorers import CompressionExplorer +from ...environment import AudioCraftEnvironment + + +@CompressionExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=2, partition=partitions) + launcher.bind_(solver='compression/debug') + + with launcher.job_array(): + # base debug task using config from solver=compression/debug + launcher() + # we can override parameters in the grid to launch additional xps + launcher({'rvq.bins': 2048, 'rvq.n_q': 4}) diff --git a/audiocraft/grids/compression/encodec_audiogen_16khz.py b/audiocraft/grids/compression/encodec_audiogen_16khz.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b41f684045594bb264cfb7f4f15d1da439382c --- /dev/null +++ b/audiocraft/grids/compression/encodec_audiogen_16khz.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Grid search file, simply list all the exp you want in `explorer`. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line. + +This grid shows how to train the new AudioGen EnCodec model at 16 kHz. +""" + +from ._explorers import CompressionExplorer +from ...environment import AudioCraftEnvironment + + +@CompressionExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=8, partition=partitions) + # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz + # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz + launcher.bind_(solver='compression/encodec_audiogen_16khz') + # replace this by the desired sound dataset + launcher.bind_(dset='internal/sounds_16khz') + # launch xp + launcher() diff --git a/audiocraft/grids/compression/encodec_base_24khz.py b/audiocraft/grids/compression/encodec_base_24khz.py new file mode 100644 index 0000000000000000000000000000000000000000..117b2b1e496ca31b3d614672b472c9213cedb4ad --- /dev/null +++ b/audiocraft/grids/compression/encodec_base_24khz.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Grid search file, simply list all the exp you want in `explorer`. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line. + +This grid shows how to train a base causal EnCodec model at 24 kHz. +""" + +from ._explorers import CompressionExplorer +from ...environment import AudioCraftEnvironment + + +@CompressionExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=8, partition=partitions) + # base causal EnCodec trained on monophonic audio sampled at 24 kHz + launcher.bind_(solver='compression/encodec_base_24khz') + # replace this by the desired dataset + launcher.bind_(dset='audio/example') + # launch xp + launcher() diff --git a/audiocraft/grids/compression/encodec_musicgen_32khz.py b/audiocraft/grids/compression/encodec_musicgen_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..9da31daa5f009f46e753601a51a06391594b8f9b --- /dev/null +++ b/audiocraft/grids/compression/encodec_musicgen_32khz.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Grid search file, simply list all the exp you want in `explorer`. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line. + +This grid shows how to train a MusicGen EnCodec model at 32 kHz. +""" + +from ._explorers import CompressionExplorer +from ...environment import AudioCraftEnvironment + + +@CompressionExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=8, partition=partitions) + # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz + # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz + launcher.bind_(solver='compression/encodec_musicgen_32khz') + # replace this by the desired music dataset + launcher.bind_(dset='internal/music_400k_32khz') + # launch xp + launcher() + launcher({ + 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol', + 'label': 'visqol', + 'evaluate.metrics.visqol': True + }) diff --git a/audiocraft/grids/diffusion/4_bands_base_32khz.py b/audiocraft/grids/diffusion/4_bands_base_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e67bcc89dd0c8e50d770e600b55f179fe19588 --- /dev/null +++ b/audiocraft/grids/diffusion/4_bands_base_32khz.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Training of the 4 diffusion models described in +"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" +(paper link). +""" + +from ._explorers import DiffusionExplorer + + +@DiffusionExplorer +def explorer(launcher): + launcher.slurm_(gpus=4, partition='learnfair') + + launcher.bind_({'solver': 'diffusion/default', + 'dset': 'internal/music_10k_32khz'}) + + with launcher.job_array(): + launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4}) + launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4}) + launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4}) + launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75}) diff --git a/audiocraft/grids/diffusion/__init__.py b/audiocraft/grids/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5737294ae16c0de52085b8dcf6825c348f617e4 --- /dev/null +++ b/audiocraft/grids/diffusion/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Diffusion grids.""" diff --git a/audiocraft/grids/diffusion/_explorers.py b/audiocraft/grids/diffusion/_explorers.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf4ca57b63f5f9308bd1178ddbde5d8f06748e5 --- /dev/null +++ b/audiocraft/grids/diffusion/_explorers.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import treetable as tt + +from .._base_explorers import BaseExplorer + + +class DiffusionExplorer(BaseExplorer): + eval_metrics = ["sisnr", "visqol"] + + def stages(self): + return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"] + + def get_grid_meta(self): + """Returns the list of Meta information to display for each XP/job. + """ + return [ + tt.leaf("index", align=">"), + tt.leaf("name", wrap=140), + tt.leaf("state"), + tt.leaf("sig", align=">"), + ] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table. + """ + return [ + tt.group( + "train", + [ + tt.leaf("epoch"), + tt.leaf("loss", ".3%"), + ], + align=">", + ), + tt.group( + "valid", + [ + tt.leaf("loss", ".3%"), + # tt.leaf("loss_0", ".3%"), + ], + align=">", + ), + tt.group( + "valid_ema", + [ + tt.leaf("loss", ".3%"), + # tt.leaf("loss_0", ".3%"), + ], + align=">", + ), + tt.group( + "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), + tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), + tt.leaf("rvm_3", ".4f"), ], align=">" + ), + tt.group( + "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), + tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), + tt.leaf("rvm_3", ".4f")], align=">" + ), + ] diff --git a/audiocraft/grids/musicgen/__init__.py b/audiocraft/grids/musicgen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f101f5a29ff85271e44e4f27545168a8f27baa --- /dev/null +++ b/audiocraft/grids/musicgen/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""MusicGen grids.""" diff --git a/audiocraft/grids/musicgen/_explorers.py b/audiocraft/grids/musicgen/_explorers.py new file mode 100644 index 0000000000000000000000000000000000000000..334836b72559a120feb8a15eef3fe96ce88a4edb --- /dev/null +++ b/audiocraft/grids/musicgen/_explorers.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import treetable as tt + +from .._base_explorers import BaseExplorer + + +class LMExplorer(BaseExplorer): + eval_metrics: tp.List[str] = [] + + def stages(self) -> tp.List[str]: + return ['train', 'valid'] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table.""" + return [ + tt.group( + 'train', + [ + tt.leaf('epoch'), + tt.leaf('duration', '.1f'), # duration in minutes + tt.leaf('ping'), + tt.leaf('ce', '.4f'), # cross entropy + tt.leaf("ppl", '.3f'), # perplexity + ], + align='>', + ), + tt.group( + 'valid', + [ + tt.leaf('ce', '.4f'), + tt.leaf('ppl', '.3f'), + tt.leaf('best_ppl', '.3f'), + ], + align='>', + ), + ] + + def process_sheep(self, sheep, history): + parts = super().process_sheep(sheep, history) + + track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher'] + best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()} + + def comparator(mode, a, b): + return a < b if mode == 'lower' else a > b + + for metrics in history: + for key, sub in metrics.items(): + for metric in track_by: + # for the validation set, keep track of best metrics (ppl in this example) + # this is so we can conveniently compare metrics between runs in the grid + if key == 'valid' and metric in sub and comparator( + track_by[metric], sub[metric], best_metrics[metric] + ): + best_metrics[metric] = sub[metric] + + if 'valid' in parts: + parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()}) + return parts + + +class GenerationEvalExplorer(BaseExplorer): + eval_metrics: tp.List[str] = [] + + def stages(self) -> tp.List[str]: + return ['evaluate'] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table.""" + return [ + tt.group( + 'evaluate', + [ + tt.leaf('epoch', '.3f'), + tt.leaf('duration', '.1f'), + tt.leaf('ping'), + tt.leaf('ce', '.4f'), + tt.leaf('ppl', '.3f'), + tt.leaf('fad', '.3f'), + tt.leaf('kld', '.3f'), + tt.leaf('text_consistency', '.3f'), + tt.leaf('chroma_cosine', '.3f'), + ], + align='>', + ), + ] diff --git a/audiocraft/grids/musicgen/musicgen_base_32khz.py b/audiocraft/grids/musicgen/musicgen_base_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..4e364614537e426f21c18a2c2a9d94b3babce051 --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_base_32khz.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='musicgen/musicgen_base_32khz') + # replace this by the desired music dataset + launcher.bind_(dset='internal/music_400k_32khz') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + large = {'model/lm/model_scale': 'large'} + + cfg_low = {'classifier_free_guidance.training_dropout': 0.2} + wd_low = {'conditioners.description.t5.word_dropout': 0.2} + + adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} + + launcher.bind_(fsdp) + + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + sub = launcher.bind() + sub() + + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(medium, adam) + + launcher.slurm_(gpus=96).bind_(label='96gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py b/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a43f37d7369b5de4542fba87c4c8739d58b1e8 --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='musicgen/musicgen_base_32khz') + # replace this by the desired music dataset + launcher.bind_(dset='internal/music_400k_32khz') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + large = {'model/lm/model_scale': 'large'} + + cfg_low = {'classifier_free_guidance.training_dropout': 0.2} + wd_low = {'conditioners.description.t5.word_dropout': 0.2} + + adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} + + # BEGINNING OF CACHE WRITING JOBS. + cache_write = { + 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k', + 'cache.write': True, + 'generate.every': 500, + 'evaluate.every': 500, + 'logging.log_updates': 50, + } + + cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'}) + cache_sub.bind_({'deadlock.use': True}) + cache_sub.slurm_(gpus=8) + with launcher.job_array(): + num_shards = 10 # total number of jobs running in parallel. + for shard in range(0, num_shards): + launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard}) + + # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE, + # OR SUFFICIENTLY AHEAD. + return + + cache = { + 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k', + } + launcher.bind_(fsdp, cache) + + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + sub = launcher.bind() + sub() + + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(medium, adam) + + launcher.slurm_(gpus=96).bind_(label='96gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py b/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..64ad3f8c77afe1ab5908e407ad14d4879e1b1ad1 --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='musicgen/musicgen_base_32khz') + # replace this by the desired music dataset + launcher.bind_(dset='internal/music_400k_32khz') + launcher.bind_(conditioner='clapemb2music') + + fsdp = {'autocast': False, 'fsdp.use': True} + cache_path = {'conditioners.description.clap.cache_path': + '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'} + text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5} + + launcher.bind_(fsdp) + + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + launcher() + launcher(text_wav_training_opt) + launcher(cache_path) + launcher(cache_path, text_wav_training_opt) diff --git a/audiocraft/grids/musicgen/musicgen_melody_32khz.py b/audiocraft/grids/musicgen/musicgen_melody_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d6710a23c117406e9724057a62eccab88ce907 --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_melody_32khz.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='musicgen/musicgen_melody_32khz') + # replace this by the desired music dataset + launcher.bind_(dset='internal/music_400k_32khz') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + large = {'model/lm/model_scale': 'large'} + + cfg_low = {'classifier_free_guidance.training_dropout': 0.2} + wd_low = {'conditioners.description.t5.word_dropout': 0.2} + + adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} + + cache_path = {'conditioners.self_wav.chroma_stem.cache_path': + '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'} + + # CACHE GENERATION JOBS + n_cache_gen_jobs = 4 + gen_sub = launcher.slurm(gpus=1) + gen_sub.bind_( + cache_path, { + # the cache is always computed over the whole file, so duration doesn't matter here. + 'dataset.segment_duration': 2., + 'dataset.batch_size': 8, + 'dataset.train.permutation_on_files': True, # try to not repeat files. + 'optim.epochs': 10, + 'model/lm/model_scale': 'xsmall', + + }) + with gen_sub.job_array(): + for gen_job in range(n_cache_gen_jobs): + gen_sub({'dataset.train.shuffle_seed': gen_job}) + + # ACTUAL TRAINING JOBS. + launcher.bind_(fsdp) + + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + sub = launcher.bind() + sub() + sub(cache_path) + + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(medium, adam) + + launcher.slurm_(gpus=96).bind_(label='96gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py b/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..39ceaf7dab15ec3f0f669cfe57ca9e932a9ab40d --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluation with objective metrics for the pretrained MusicGen models. +This grid takes signature from the training grid and runs evaluation-only stage. + +When running the grid for the first time, please use: +REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval +and re-use the REGEN=1 option when the grid is changed to force regenerating it. + +Note that you need the proper metrics external libraries setup to use all +the objective metrics activated in this grid. Refer to the README for more information. +""" + +import os + +from ._explorers import GenerationEvalExplorer +from ...environment import AudioCraftEnvironment +from ... import train + + +def eval(launcher, batch_size: int = 32, eval_melody: bool = False): + opts = { + 'dset': 'audio/musiccaps_32khz', + 'solver/musicgen/evaluation': 'objective_eval', + 'execute_only': 'evaluate', + '+dataset.evaluate.batch_size': batch_size, + '+metrics.fad.tf.batch_size': 16, + } + # chroma-specific evaluation + chroma_opts = { + 'dset': 'internal/music_400k_32khz', + 'dataset.evaluate.segment_duration': 30, + 'dataset.evaluate.num_samples': 1000, + 'evaluate.metrics.chroma_cosine': True, + 'evaluate.metrics.fad': False, + 'evaluate.metrics.kld': False, + 'evaluate.metrics.text_consistency': False, + } + # binary for FAD computation: replace this path with your own path + metrics_opts = { + 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' + } + opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.} + opt2 = {'transformer_lm.two_step_cfg': True} + + sub = launcher.bind(opts) + sub.bind_(metrics_opts) + + # base objective metrics + sub(opt1, opt2) + + if eval_melody: + # chroma-specific metrics + sub(opt1, opt2, chroma_opts) + + +@GenerationEvalExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=4, partition=partitions) + + if 'REGEN' not in os.environ: + folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] + with launcher.job_array(): + for sig in folder.iterdir(): + if not sig.is_symlink(): + continue + xp = train.main.get_xp_from_sig(sig.name) + launcher(xp.argv) + return + + with launcher.job_array(): + musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz") + musicgen_base.bind_({'autocast': False, 'fsdp.use': True}) + + # base musicgen models + musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'}) + eval(musicgen_base_small, batch_size=128) + + musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'}) + musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'}) + eval(musicgen_base_medium, batch_size=128) + + musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'}) + musicgen_base_large.bind_({'model/lm/model_scale': 'large'}) + eval(musicgen_base_large, batch_size=128) + + # melody musicgen model + musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz") + musicgen_melody.bind_({'autocast': False, 'fsdp.use': True}) + + musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'}) + musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'}) + eval(musicgen_melody_medium, batch_size=128, eval_melody=True) diff --git a/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py b/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py new file mode 100644 index 0000000000000000000000000000000000000000..2904e73de08f1c9b844818558d739715776284d6 --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from ._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='musicgen/musicgen_base_32khz') + # replace this by the desired music dataset, which needs to be stereo + launcher.bind_(dset='audio/example') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + large = {'model/lm/model_scale': 'large'} + + cfg_low = {'classifier_free_guidance.training_dropout': 0.2} + wd_low = {'conditioners.description.t5.word_dropout': 0.2} + + adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} + + stereo = { + 'codebooks_pattern.delay.delays': [0, 0, 1, 1, 2, 2, 3, 3], + 'transformer_lm.n_q': 8, + 'interleave_stereo_codebooks.use': True, + 'channels': 2, + } + + # You must follow the instructions in docs/MUSICGEN.md about the creation + # of the proper fine tuning checkpoints. We will assume they are stored under + # ~/checkpoints/{mode_name}. + + checkpoints = Path.home() / 'checkpoints' + + launcher.bind_(fsdp, stereo, {'optim.epochs': 100}) + + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-small.th')}) + sub() + + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-medium.th')}) + sub(medium, adam) + + launcher.slurm_(gpus=96).bind_(label='96gpus') + with launcher.job_array(): + sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-large.th')}) + sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/losses/__init__.py b/audiocraft/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d55107b2c11822cab749ed3683cf19020802898a --- /dev/null +++ b/audiocraft/losses/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loss related classes and functions. In particular the loss balancer from +EnCodec, and the usual spectral losses.""" + +# flake8: noqa +from .balancer import Balancer +from .sisnr import SISNR +from .stftloss import ( + LogSTFTMagnitudeLoss, + MRSTFTLoss, + SpectralConvergenceLoss, + STFTLoss +) +from .specloss import ( + MelSpectrogramL1Loss, + MultiScaleMelSpectrogramLoss, +) diff --git a/audiocraft/losses/balancer.py b/audiocraft/losses/balancer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0ac8adebab8cdee8f82351965195dc02800d18 --- /dev/null +++ b/audiocraft/losses/balancer.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import flashy +import torch +from torch import autograd + + +class Balancer: + """Loss balancer. + + The loss balancer combines losses together to compute gradients for the backward. + Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...` + not having any dependence on `f`, the balancer can efficiently normalize the partial gradients + `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between + the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient + going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy + interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown. + + Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be + (with `avg` an exponential moving average over the updates), + + G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i) + + If `balance_grads` is False, this is deactivated, and instead the gradient will just be the + standard sum of the partial gradients with the given weights. + + A call to the backward method of the balancer will compute the the partial gradients, + combining all the losses and potentially rescaling the gradients, + which can help stabilize the training and reason about multiple losses with varying scales. + The obtained gradient with respect to `y` is then back-propagated to `f(...)`. + + Expected usage: + + weights = {'loss_a': 1, 'loss_b': 4} + balancer = Balancer(weights, ...) + losses: dict = {} + losses['loss_a'] = compute_loss_a(x, y) + losses['loss_b'] = compute_loss_b(x, y) + if model.training(): + effective_loss = balancer.backward(losses, x) + + Args: + weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys + from the backward method to match the weights keys to assign weight to each of the provided loss. + balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the + overall gradient, rather than a constant multiplier. + total_norm (float): Reference norm when rescaling gradients, ignored otherwise. + emay_decay (float): EMA decay for averaging the norms. + per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds + when rescaling the gradients. + epsilon (float): Epsilon value for numerical stability. + monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients + coming from each loss, when calling `backward()`. + """ + def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1., + ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12, + monitor: bool = False): + self.weights = weights + self.per_batch_item = per_batch_item + self.total_norm = total_norm or 1. + self.averager = flashy.averager(ema_decay or 1.) + self.epsilon = epsilon + self.monitor = monitor + self.balance_grads = balance_grads + self._metrics: tp.Dict[str, tp.Any] = {} + + @property + def metrics(self): + return self._metrics + + def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor: + """Compute the backward and return the effective train loss, e.g. the loss obtained from + computing the effective weights. If `balance_grads` is True, the effective weights + are the one that needs to be applied to each gradient to respect the desired relative + scale of gradients coming from each loss. + + Args: + losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`. + input (torch.Tensor): the input of the losses, typically the output of the model. + This should be the single point of dependence between the losses + and the model being trained. + """ + norms = {} + grads = {} + for name, loss in losses.items(): + # Compute partial derivative of the less with respect to the input. + grad, = autograd.grad(loss, [input], retain_graph=True) + if self.per_batch_item: + # We do not average the gradient over the batch dimension. + dims = tuple(range(1, grad.dim())) + norm = grad.norm(dim=dims, p=2).mean() + else: + norm = grad.norm(p=2) + norms[name] = norm + grads[name] = grad + + count = 1 + if self.per_batch_item: + count = len(grad) + # Average norms across workers. Theoretically we should average the + # squared norm, then take the sqrt, but it worked fine like that. + avg_norms = flashy.distrib.average_metrics(self.averager(norms), count) + # We approximate the total norm of the gradient as the sums of the norms. + # Obviously this can be very incorrect if all gradients are aligned, but it works fine. + total = sum(avg_norms.values()) + + self._metrics = {} + if self.monitor: + # Store the ratio of the total gradient represented by each loss. + for k, v in avg_norms.items(): + self._metrics[f'ratio_{k}'] = v / total + + total_weights = sum([self.weights[k] for k in avg_norms]) + assert total_weights > 0. + desired_ratios = {k: w / total_weights for k, w in self.weights.items()} + + out_grad = torch.zeros_like(input) + effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype) + for name, avg_norm in avg_norms.items(): + if self.balance_grads: + # g_balanced = g / avg(||g||) * total_norm * desired_ratio + scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm) + else: + # We just do regular weighted sum of the gradients. + scale = self.weights[name] + out_grad.add_(grads[name], alpha=scale) + effective_loss += scale * losses[name].detach() + # Send the computed partial derivative with respect to the output of the model to the model. + input.backward(out_grad) + return effective_loss diff --git a/audiocraft/losses/sisnr.py b/audiocraft/losses/sisnr.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b8ee03507dccf0327b1f2f57298b56f38827fe --- /dev/null +++ b/audiocraft/losses/sisnr.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import torch +from torch import nn +from torch.nn import functional as F + + +def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + This will pad the input so that `F = ceil(T / K)`. + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, "data should be contiguous" + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +def _center(x: torch.Tensor) -> torch.Tensor: + return x - x.mean(-1, True) + + +def _norm2(x: torch.Tensor) -> torch.Tensor: + return x.pow(2).sum(-1, True) + + +class SISNR(nn.Module): + """SISNR loss. + + Input should be [B, C, T], output is scalar. + + ..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`). + Consequently, lower scores are better in terms of reconstruction quality, + in particular, it should be negative if training goes well. This done this way so + that this module can also be used as a loss function for training model. + + Args: + sample_rate (int): Sample rate. + segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on + entire audio only. + overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap. + epsilon (float): Epsilon value for numerical stability. + """ + def __init__( + self, + sample_rate: int = 16000, + segment: tp.Optional[float] = 20, + overlap: float = 0.5, + epsilon: float = torch.finfo(torch.float32).eps, + ): + super().__init__() + self.sample_rate = sample_rate + self.segment = segment + self.overlap = overlap + self.epsilon = epsilon + + def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor: + B, C, T = ref_sig.shape + assert ref_sig.shape == out_sig.shape + + if self.segment is None: + frame = T + stride = T + else: + frame = int(self.segment * self.sample_rate) + stride = int(frame * (1 - self.overlap)) + + epsilon = self.epsilon * frame # make epsilon prop to frame size. + + gt = _unfold(ref_sig, frame, stride) + est = _unfold(out_sig, frame, stride) + if self.segment is None: + assert gt.shape[-1] == 1 + + gt = _center(gt) + est = _center(est) + dot = torch.einsum("bcft,bcft->bcf", gt, est) + + proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt)) + noise = est - proj + + sisnr = 10 * ( + torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise)) + ) + return -1 * sisnr[..., 0].mean() diff --git a/audiocraft/losses/specloss.py b/audiocraft/losses/specloss.py new file mode 100644 index 0000000000000000000000000000000000000000..11f2eb3e5c44b542a02f13db64bfb22fa0d3d212 --- /dev/null +++ b/audiocraft/losses/specloss.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import numpy as np +from torchaudio.transforms import MelSpectrogram +import torch +from torch import nn +from torch.nn import functional as F + +from ..modules import pad_for_conv1d + + +class MelSpectrogramWrapper(nn.Module): + """Wrapper around MelSpectrogram torchaudio transform providing proper padding + and additional post-processing including log scaling. + + Args: + n_mels (int): Number of mel bins. + n_fft (int): Number of fft. + hop_length (int): Hop size. + win_length (int): Window length. + n_mels (int): Number of mel bins. + sample_rate (int): Sample rate. + f_min (float or None): Minimum frequency. + f_max (float or None): Maximum frequency. + log (bool): Whether to scale with log. + normalized (bool): Whether to normalize the melspectrogram. + floor_level (float): Floor level based on human perception (default=1e-5). + """ + def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None, + n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None, + log: bool = True, normalized: bool = False, floor_level: float = 1e-5): + super().__init__() + self.n_fft = n_fft + hop_length = int(hop_length) + self.hop_length = hop_length + self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, + win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized, + window_fn=torch.hann_window, center=False) + self.floor_level = floor_level + self.log = log + + def forward(self, x): + p = int((self.n_fft - self.hop_length) // 2) + if len(x.shape) == 2: + x = x.unsqueeze(1) + x = F.pad(x, (p, p), "reflect") + # Make sure that all the frames are full. + # The combination of `pad_for_conv1d` and the above padding + # will make the output of size ceil(T / hop). + x = pad_for_conv1d(x, self.n_fft, self.hop_length) + self.mel_transform.to(x.device) + mel_spec = self.mel_transform(x) + B, C, freqs, frame = mel_spec.shape + if self.log: + mel_spec = torch.log10(self.floor_level + mel_spec) + return mel_spec.reshape(B, C * freqs, frame) + + +class MelSpectrogramL1Loss(torch.nn.Module): + """L1 Loss on MelSpectrogram. + + Args: + sample_rate (int): Sample rate. + n_fft (int): Number of fft. + hop_length (int): Hop size. + win_length (int): Window length. + n_mels (int): Number of mel bins. + f_min (float or None): Minimum frequency. + f_max (float or None): Maximum frequency. + log (bool): Whether to scale with log. + normalized (bool): Whether to normalize the melspectrogram. + floor_level (float): Floor level value based on human perception (default=1e-5). + """ + def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, + n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None, + log: bool = True, normalized: bool = False, floor_level: float = 1e-5): + super().__init__() + self.l1 = torch.nn.L1Loss() + self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length, + n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, + log=log, normalized=normalized, floor_level=floor_level) + + def forward(self, x, y): + self.melspec.to(x.device) + s_x = self.melspec(x) + s_y = self.melspec(y) + return self.l1(s_x, s_y) + + +class MultiScaleMelSpectrogramLoss(nn.Module): + """Multi-Scale spectrogram loss (msspec). + + Args: + sample_rate (int): Sample rate. + range_start (int): Power of 2 to use for the first scale. + range_stop (int): Power of 2 to use for the last scale. + n_mels (int): Number of mel bins. + f_min (float): Minimum frequency. + f_max (float or None): Maximum frequency. + normalized (bool): Whether to normalize the melspectrogram. + alphas (bool): Whether to use alphas as coefficients or not. + floor_level (float): Floor level value based on human perception (default=1e-5). + """ + def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11, + n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None, + normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5): + super().__init__() + l1s = list() + l2s = list() + self.alphas = list() + self.total = 0 + self.normalized = normalized + for i in range(range_start, range_end): + l1s.append( + MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i, + n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, + log=False, normalized=normalized, floor_level=floor_level)) + l2s.append( + MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i, + n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, + log=True, normalized=normalized, floor_level=floor_level)) + if alphas: + self.alphas.append(np.sqrt(2 ** i - 1)) + else: + self.alphas.append(1) + self.total += self.alphas[-1] + 1 + + self.l1s = nn.ModuleList(l1s) + self.l2s = nn.ModuleList(l2s) + + def forward(self, x, y): + loss = 0.0 + self.l1s.to(x.device) + self.l2s.to(x.device) + for i in range(len(self.alphas)): + s_x_1 = self.l1s[i](x) + s_y_1 = self.l1s[i](y) + s_x_2 = self.l2s[i](x) + s_y_2 = self.l2s[i](y) + loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2) + if self.normalized: + loss = loss / self.total + return loss diff --git a/audiocraft/losses/stftloss.py b/audiocraft/losses/stftloss.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad4b7d3324ee5b0e6064b6f71cf8caf0fdc3be7 --- /dev/null +++ b/audiocraft/losses/stftloss.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# Adapted from MIT code under the original license +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +import typing as tp + +import torch +from torch import nn +from torch.nn import functional as F + + +# TODO: Replace with torchaudio.STFT? +def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int, + window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor: + """Perform STFT and convert to magnitude spectrogram. + + Args: + x: Input signal tensor (B, C, T). + fft_size (int): FFT size. + hop_length (int): Hop size. + win_length (int): Window length. + window (torch.Tensor or None): Window function type. + normalized (bool): Whether to normalize the STFT or not. + + Returns: + torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1). + """ + B, C, T = x.shape + x_stft = torch.stft( + x.view(-1, T), fft_size, hop_length, win_length, window, + normalized=normalized, return_complex=True, + ) + x_stft = x_stft.view(B, C, *x_stft.shape[1:]) + real = x_stft.real + imag = x_stft.imag + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) + + +class SpectralConvergenceLoss(nn.Module): + """Spectral convergence loss. + """ + def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): + super().__init__() + self.epsilon = epsilon + + def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): + """Calculate forward propagation. + + Args: + x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + torch.Tensor: Spectral convergence loss value. + """ + return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon) + + +class LogSTFTMagnitudeLoss(nn.Module): + """Log STFT magnitude loss. + + Args: + epsilon (float): Epsilon value for numerical stability. + """ + def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): + super().__init__() + self.epsilon = epsilon + + def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): + """Calculate forward propagation. + + Args: + x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + torch.Tensor: Log STFT magnitude loss value. + """ + return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag)) + + +class STFTLosses(nn.Module): + """STFT losses. + + Args: + n_fft (int): Size of FFT. + hop_length (int): Hop length. + win_length (int): Window length. + window (str): Window function type. + normalized (bool): Whether to use normalized STFT or not. + epsilon (float): Epsilon for numerical stability. + """ + def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600, + window: str = "hann_window", normalized: bool = False, + epsilon: float = torch.finfo(torch.float32).eps): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.register_buffer("window", getattr(torch, window)(win_length)) + self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon) + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (torch.Tensor): Predicted signal (B, T). + y (torch.Tensor): Groundtruth signal (B, T). + Returns: + torch.Tensor: Spectral convergence loss value. + torch.Tensor: Log STFT magnitude loss value. + """ + x_mag = _stft(x, self.n_fft, self.hop_length, + self.win_length, self.window, self.normalized) # type: ignore + y_mag = _stft(y, self.n_fft, self.hop_length, + self.win_length, self.window, self.normalized) # type: ignore + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class STFTLoss(nn.Module): + """Single Resolution STFT loss. + + Args: + n_fft (int): Nb of FFT. + hop_length (int): Hop length. + win_length (int): Window length. + window (str): Window function type. + normalized (bool): Whether to use normalized STFT or not. + epsilon (float): Epsilon for numerical stability. + factor_sc (float): Coefficient for the spectral loss. + factor_mag (float): Coefficient for the magnitude loss. + """ + def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600, + window: str = "hann_window", normalized: bool = False, + factor_sc: float = 0.1, factor_mag: float = 0.1, + epsilon: float = torch.finfo(torch.float32).eps): + super().__init__() + self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon) + self.factor_sc = factor_sc + self.factor_mag = factor_mag + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (torch.Tensor): Predicted signal (B, T). + y (torch.Tensor): Groundtruth signal (B, T). + Returns: + torch.Tensor: Single resolution STFT loss. + """ + sc_loss, mag_loss = self.loss(x, y) + return self.factor_sc * sc_loss + self.factor_mag * mag_loss + + +class MRSTFTLoss(nn.Module): + """Multi resolution STFT loss. + + Args: + n_ffts (Sequence[int]): Sequence of FFT sizes. + hop_lengths (Sequence[int]): Sequence of hop sizes. + win_lengths (Sequence[int]): Sequence of window lengths. + window (str): Window function type. + factor_sc (float): Coefficient for the spectral loss. + factor_mag (float): Coefficient for the magnitude loss. + normalized (bool): Whether to use normalized STFT or not. + epsilon (float): Epsilon for numerical stability. + """ + def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50], + win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window", + factor_sc: float = 0.1, factor_mag: float = 0.1, + normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths): + self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)] + self.factor_sc = factor_sc + self.factor_mag = factor_mag + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (torch.Tensor): Predicted signal (B, T). + y (torch.Tensor): Groundtruth signal (B, T). + Returns: + torch.Tensor: Multi resolution STFT loss. + """ + sc_loss = torch.Tensor([0.0]) + mag_loss = torch.Tensor([0.0]) + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return self.factor_sc * sc_loss + self.factor_mag * mag_loss diff --git a/audiocraft/metrics/__init__.py b/audiocraft/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3474bdc4f1c88b21904d2a21ba077c93a8a70c8b --- /dev/null +++ b/audiocraft/metrics/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc. +""" +# flake8: noqa +from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric +from .chroma_cosinesim import ChromaCosineSimilarityMetric +from .fad import FrechetAudioDistanceMetric +from .kld import KLDivergenceMetric, PasstKLDivergenceMetric +from .rvm import RelativeVolumeMel +from .visqol import ViSQOL diff --git a/audiocraft/metrics/chroma_cosinesim.py b/audiocraft/metrics/chroma_cosinesim.py new file mode 100644 index 0000000000000000000000000000000000000000..40c26081b803c2017fae1b6d7d086f0b0e074cef --- /dev/null +++ b/audiocraft/metrics/chroma_cosinesim.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torchmetrics + +from ..data.audio_utils import convert_audio +from ..modules.chroma import ChromaExtractor + + +class ChromaCosineSimilarityMetric(torchmetrics.Metric): + """Chroma cosine similarity metric. + + This metric extracts a chromagram for a reference waveform and + a generated waveform and compares each frame using the cosine similarity + function. The output is the mean cosine similarity. + + Args: + sample_rate (int): Sample rate used by the chroma extractor. + n_chroma (int): Number of chroma used by the chroma extractor. + radix2_exp (int): Exponent for the chroma extractor. + argmax (bool): Whether the chroma extractor uses argmax. + eps (float): Epsilon for cosine similarity computation. + """ + def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8): + super().__init__() + self.chroma_sample_rate = sample_rate + self.n_chroma = n_chroma + self.eps = eps + self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma, + radix2_exp=radix2_exp, argmax=argmax) + self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, targets: torch.Tensor, + sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: + """Compute cosine similarity between chromagrams and accumulate scores over the dataset.""" + if preds.size(0) == 0: + return + + assert preds.shape == targets.shape, ( + f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}") + assert preds.size(0) == sizes.size(0), ( + f"Number of items in preds ({preds.shape}) mismatch ", + f"with sizes ({sizes.shape})") + assert preds.size(0) == sample_rates.size(0), ( + f"Number of items in preds ({preds.shape}) mismatch ", + f"with sample_rates ({sample_rates.shape})") + assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch" + + device = self.weight.device + preds, targets = preds.to(device), targets.to(device) # type: ignore + sample_rate = sample_rates[0].item() + preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1) + targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1) + gt_chroma = self.chroma_extractor(targets) + gen_chroma = self.chroma_extractor(preds) + chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int() + for i in range(len(gt_chroma)): + t = int(chroma_lens[i].item()) + cosine_sim = torch.nn.functional.cosine_similarity( + gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps) + self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore + self.weight += torch.tensor(t) # type: ignore + + def compute(self) -> float: + """Computes the average cosine similarty across all generated/target chromagrams pairs.""" + assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore + return (self.cosine_sum / self.weight).item() # type: ignore diff --git a/audiocraft/metrics/clap_consistency.py b/audiocraft/metrics/clap_consistency.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a6c61ae177533ca2fb17e25bc77d2acbbe3791 --- /dev/null +++ b/audiocraft/metrics/clap_consistency.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +import typing as tp + +import torch +import torchmetrics +from transformers import RobertaTokenizer # type: ignore + +from ..data.audio_utils import convert_audio +from ..environment import AudioCraftEnvironment +from ..utils.utils import load_clap_state_dict + +try: + import laion_clap # type: ignore +except ImportError: + laion_clap = None + + +class TextConsistencyMetric(torchmetrics.Metric): + """Text consistency metric measuring consistency between audio and text pairs.""" + + def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: + raise NotImplementedError("implement how to update the metric from the audio and text pairs.") + + def compute(self): + raise NotImplementedError("implement how to compute the final metric score.") + + +class CLAPTextConsistencyMetric(TextConsistencyMetric): + """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP). + + This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf) + or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf). + + As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the + similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as + well as the generated audio based on them, and define the MCC metric as the average cosine similarity + between these embeddings. + + Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP + """ + def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False): + super().__init__() + if laion_clap is None: + raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'") + self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum") + self._initialize_model(model_path, model_arch, enable_fusion) + + def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool): + model_path = AudioCraftEnvironment.resolve_reference_path(model_path) + self.tokenize = RobertaTokenizer.from_pretrained('roberta-base') + self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) + self.model_sample_rate = 48_000 + load_clap_state_dict(self.model, model_path) + self.model.eval() + + def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: + # we use the default params from CLAP module here as well + return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") + + def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: + """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.""" + assert audio.size(0) == len(text), "Number of audio and text samples should match" + assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate" + sample_rate = int(sample_rates[0].item()) + # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T] + audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1) + audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True) + text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) + # cosine similarity between the text and the audio embedding + cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8) + self.cosine_sum += cosine_sim.sum(dim=0) + self.weight += torch.tensor(cosine_sim.size(0)) + + def compute(self): + """Computes the average cosine similarty across all audio/text pairs.""" + assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore + return (self.cosine_sum / self.weight).item() # type: ignore diff --git a/audiocraft/metrics/fad.py b/audiocraft/metrics/fad.py new file mode 100644 index 0000000000000000000000000000000000000000..de66138dbb14fd4246bbfe590bddfd5beaf1ed8c --- /dev/null +++ b/audiocraft/metrics/fad.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +import os +import subprocess +import tempfile +import typing as tp + +from audiocraft.data.audio import audio_write +from audiocraft.data.audio_utils import convert_audio +import flashy +import torch +import torchmetrics + +from ..environment import AudioCraftEnvironment + + +logger = logging.getLogger(__name__) + +VGGISH_SAMPLE_RATE = 16_000 +VGGISH_CHANNELS = 1 + + +class FrechetAudioDistanceMetric(torchmetrics.Metric): + """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research. + + From: D.C. Dowson & B.V. Landau The Fréchet distance between + multivariate normal distributions + https://doi.org/10.1016/0047-259X(82)90077-X + The Fréchet distance between two multivariate gaussians, + `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`. + d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y)) + = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y) + - 2 * Tr(sqrt(sigma_x*sigma_y))) + + To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup + from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance + We provide the below instructions as reference but we do not guarantee for further support + in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0. + + We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda). + + 1. Get the code and models following the repository instructions. We used the steps below: + git clone git@github.com:google-research/google-research.git + git clone git@github.com:tensorflow/models.git + mkdir google-research/tensorflow_models + touch google-research/tensorflow_models/__init__.py + cp -r models/research/audioset google-research/tensorflow_models/ + touch google-research/tensorflow_models/audioset/__init__.py + echo "from .vggish import mel_features, vggish_params, vggish_slim" > \ + google-research/tensorflow_models/audioset/__init__.py + # we can now remove the tensorflow models repository + # rm -r models + cd google-research + Follow the instructions to download the vggish checkpoint. AudioCraft base configuration + assumes it is placed in the AudioCraft reference dir. + + Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3: + - Update xrange for range in: + https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py + - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to + `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in + https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py + - Update `import vggish_params as params` to `from . import vggish_params as params` in: + https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py + - Add flag to provide a given batch size for running the AudioSet model in: + https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py + ``` + flags.DEFINE_integer('batch_size', 64, + 'Number of samples in the batch for AudioSet model.') + ``` + Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding: + `batch_size=FLAGS.batch_size` to the provided parameters. + + 2. Follow instructions for the library installation and a valid TensorFlow installation + ``` + # e.g. instructions from: https://www.tensorflow.org/install/pip + conda install -c conda-forge cudatoolkit=11.8.0 + python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.* + mkdir -p $CONDA_PREFIX/etc/conda/activate.d + echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \ + >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \ + >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + # Verify install: on a machine with GPU device + python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))" + ``` + + Now install frechet_audio_distance required dependencies: + ``` + # We assume we already have TensorFlow installed from the above steps + pip install apache-beam numpy scipy tf_slim + ``` + + Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup + (you may want to specify --model_ckpt flag pointing to the model's path). + + 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable + and Tensorflow library path from the above installation steps: + export TF_PYTHON_EXE="" + export TF_LIBRARY_PATH="" + + e.g. assuming we have installed everything in a dedicated conda env + with python 3.10 that is currently active: + export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python" + export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib" + + Finally you may want to export the following variable: + export TF_FORCE_GPU_ALLOW_GROWTH=true + See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth + + You can save those environment variables in your training conda env, when currently active: + `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh` + e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval, + and the training conda env is named audiocraft: + ``` + # activate training env + conda activate audiocraft + # get path to all envs + CONDA_ENV_DIR=$(dirname $CONDA_PREFIX) + # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric + touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \ + $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \ + $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + # optionally: + echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh + # you may need to reactivate the audiocraft env for this to take effect + ``` + + Args: + bin (Path or str): Path to installed frechet audio distance code. + model_path (Path or str): Path to Tensorflow checkpoint for the model + used to compute statistics over the embedding beams. + format (str): Audio format used to save files. + log_folder (Path or str, optional): Path where to write process logs. + """ + def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str], + format: str = "wav", batch_size: tp.Optional[int] = None, + log_folder: tp.Optional[tp.Union[Path, str]] = None): + super().__init__() + self.model_sample_rate = VGGISH_SAMPLE_RATE + self.model_channels = VGGISH_CHANNELS + self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path) + assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}" + self.format = format + self.batch_size = batch_size + self.bin = bin + self.tf_env = {"PYTHONPATH": str(self.bin)} + self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python' + logger.info("Python exe for TF is %s", self.python_path) + if 'TF_LIBRARY_PATH' in os.environ: + self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH'] + if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ: + self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] + logger.info("Env for TF is %r", self.tf_env) + self.reset(log_folder) + self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum") + + def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None): + """Reset torchmetrics.Metrics state.""" + log_folder = Path(log_folder or tempfile.mkdtemp()) + self.tmp_dir = log_folder / 'fad' + self.tmp_dir.mkdir(exist_ok=True) + self.samples_tests_dir = self.tmp_dir / 'tests' + self.samples_tests_dir.mkdir(exist_ok=True) + self.samples_background_dir = self.tmp_dir / 'background' + self.samples_background_dir.mkdir(exist_ok=True) + self.manifest_tests = self.tmp_dir / 'files_tests.cvs' + self.manifest_background = self.tmp_dir / 'files_background.cvs' + self.stats_tests_dir = self.tmp_dir / 'stats_tests' + self.stats_background_dir = self.tmp_dir / 'stats_background' + self.counter = 0 + + def update(self, preds: torch.Tensor, targets: torch.Tensor, + sizes: torch.Tensor, sample_rates: torch.Tensor, + stems: tp.Optional[tp.List[str]] = None): + """Update torchmetrics.Metrics by saving the audio and updating the manifest file.""" + assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}" + num_samples = preds.shape[0] + assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0) + assert stems is None or num_samples == len(set(stems)) + for i in range(num_samples): + self.total_files += 1 # type: ignore + self.counter += 1 + wav_len = int(sizes[i].item()) + sample_rate = int(sample_rates[i].item()) + pred_wav = preds[i] + target_wav = targets[i] + pred_wav = pred_wav[..., :wav_len] + target_wav = target_wav[..., :wav_len] + stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}' + # dump audio files + try: + pred_wav = convert_audio( + pred_wav.unsqueeze(0), from_rate=sample_rate, + to_rate=self.model_sample_rate, to_channels=1).squeeze(0) + audio_write( + self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate, + format=self.format, strategy="peak") + except Exception as e: + logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}") + try: + # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying + # the original audio when writing it + target_wav = convert_audio( + target_wav.unsqueeze(0), from_rate=sample_rate, + to_rate=self.model_sample_rate, to_channels=1).squeeze(0) + audio_write( + self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate, + format=self.format, strategy="peak") + except Exception as e: + logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}") + + def _get_samples_name(self, is_background: bool): + return 'background' if is_background else 'tests' + + def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None): + if is_background: + input_samples_dir = self.samples_background_dir + input_filename = self.manifest_background + stats_name = self.stats_background_dir + else: + input_samples_dir = self.samples_tests_dir + input_filename = self.manifest_tests + stats_name = self.stats_tests_dir + beams_name = self._get_samples_name(is_background) + log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log' + + logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}") + with open(input_filename, "w") as fout: + for path in Path(input_samples_dir).glob(f"*.{self.format}"): + fout.write(f"{str(path)}\n") + + cmd = [ + self.python_path, "-m", + "frechet_audio_distance.create_embeddings_main", + "--model_ckpt", f"{self.model_path}", + "--input_files", f"{str(input_filename)}", + "--stats", f"{str(stats_name)}", + ] + if self.batch_size is not None: + cmd += ["--batch_size", str(self.batch_size)] + logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}") + env = os.environ + if gpu_index is not None: + env["CUDA_VISIBLE_DEVICES"] = str(gpu_index) + process = subprocess.Popen( + cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT) + return process, log_file + + def _compute_fad_score(self, gpu_index: tp.Optional[int] = None): + cmd = [ + self.python_path, "-m", "frechet_audio_distance.compute_fad", + "--test_stats", f"{str(self.stats_tests_dir)}", + "--background_stats", f"{str(self.stats_background_dir)}", + ] + logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}") + env = os.environ + if gpu_index is not None: + env["CUDA_VISIBLE_DEVICES"] = str(gpu_index) + result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True) + if result.returncode: + logger.error( + "Error with FAD computation from stats: \n %s \n %s", + result.stdout.decode(), result.stderr.decode() + ) + raise RuntimeError("Error while executing FAD computation from stats") + try: + # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more + fad_score = float(result.stdout[4:]) + return fad_score + except Exception as e: + raise RuntimeError(f"Error parsing FAD score from command stdout: {e}") + + def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None: + beams_name = self._get_samples_name(is_background) + if returncode: + with open(log_file, "r") as f: + error_log = f.read() + logger.error(error_log) + os._exit(1) + else: + logger.info(f"Successfully computed embedding beams on {beams_name} samples.") + + def _parallel_create_embedding_beams(self, num_of_gpus: int): + assert num_of_gpus > 0 + logger.info("Creating embeddings beams in a parallel manner on different GPUs") + tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0) + bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1) + tests_beams_code = tests_beams_process.wait() + bg_beams_code = bg_beams_process.wait() + self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False) + self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True) + + def _sequential_create_embedding_beams(self): + logger.info("Creating embeddings beams in a sequential manner") + tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False) + tests_beams_code = tests_beams_process.wait() + self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False) + bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True) + bg_beams_code = bg_beams_process.wait() + self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True) + + @flashy.distrib.rank_zero_only + def _local_compute_frechet_audio_distance(self): + """Compute Frechet Audio Distance score calling TensorFlow API.""" + num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + if num_of_gpus > 1: + self._parallel_create_embedding_beams(num_of_gpus) + else: + self._sequential_create_embedding_beams() + fad_score = self._compute_fad_score(gpu_index=0) + return fad_score + + def compute(self) -> float: + """Compute metrics.""" + assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore + fad_score = self._local_compute_frechet_audio_distance() + logger.warning(f"FAD score = {fad_score}") + fad_score = flashy.distrib.broadcast_object(fad_score, src=0) + return fad_score diff --git a/audiocraft/metrics/kld.py b/audiocraft/metrics/kld.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbbcda09b0419be4d51ae6698292ff7221e47e6 --- /dev/null +++ b/audiocraft/metrics/kld.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from functools import partial +import logging +import os +import typing as tp + +import torch +import torchmetrics + +from ..data.audio_utils import convert_audio + + +logger = logging.getLogger(__name__) + + +class _patch_passt_stft: + """Decorator to patch torch.stft in PaSST.""" + def __init__(self): + self.old_stft = torch.stft + + def __enter__(self): + # return_complex is a mandatory parameter in latest torch versions + # torch is throwing RuntimeErrors when not set + torch.stft = partial(torch.stft, return_complex=False) + + def __exit__(self, *exc): + torch.stft = self.old_stft + + +def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor: + """Computes the elementwise KL-Divergence loss between probability distributions + from generated samples and target samples. + + Args: + pred_probs (torch.Tensor): Probabilities for each label obtained + from a classifier on generated audio. Expected shape is [B, num_classes]. + target_probs (torch.Tensor): Probabilities for each label obtained + from a classifier on target audio. Expected shape is [B, num_classes]. + epsilon (float): Epsilon value. + Returns: + kld (torch.Tensor): KLD loss between each generated sample and target pair. + """ + kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none") + return kl_div.sum(-1) + + +class KLDivergenceMetric(torchmetrics.Metric): + """Base implementation for KL Divergence metric. + + The KL divergence is measured between probability distributions + of class predictions returned by a pre-trained audio classification model. + When the KL-divergence is low, the generated audio is expected to + have similar acoustic characteristics as the reference audio, + according to the classifier. + """ + def __init__(self): + super().__init__() + self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum") + + def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, + sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: + """Get model output given provided input tensor. + + Args: + x (torch.Tensor): Input audio tensor of shape [B, C, T]. + sizes (torch.Tensor): Actual audio sample length, of shape [B]. + sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. + Returns: + probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes]. + """ + raise NotImplementedError("implement method to extract label distributions from the model.") + + def update(self, preds: torch.Tensor, targets: torch.Tensor, + sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: + """Calculates running KL-Divergence loss between batches of audio + preds (generated) and target (ground-truth) + Args: + preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T]. + targets (torch.Tensor): Target samples to compare against, of shape [B, C, T]. + sizes (torch.Tensor): Actual audio sample length, of shape [B]. + sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. + """ + assert preds.shape == targets.shape + assert preds.size(0) > 0, "Cannot update the loss with empty tensors" + preds_probs = self._get_label_distribution(preds, sizes, sample_rates) + targets_probs = self._get_label_distribution(targets, sizes, sample_rates) + if preds_probs is not None and targets_probs is not None: + assert preds_probs.shape == targets_probs.shape + kld_scores = kl_divergence(preds_probs, targets_probs) + assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!" + self.kld_pq_sum += torch.sum(kld_scores) + kld_qp_scores = kl_divergence(targets_probs, preds_probs) + self.kld_qp_sum += torch.sum(kld_qp_scores) + self.weight += torch.tensor(kld_scores.size(0)) + + def compute(self) -> dict: + """Computes KL-Divergence across all evaluated pred/target pairs.""" + weight: float = float(self.weight.item()) # type: ignore + assert weight > 0, "Unable to compute with total number of comparisons <= 0" + logger.info(f"Computing KL divergence on a total of {weight} samples") + kld_pq = self.kld_pq_sum.item() / weight # type: ignore + kld_qp = self.kld_qp_sum.item() / weight # type: ignore + kld_both = kld_pq + kld_qp + return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both} + + +class PasstKLDivergenceMetric(KLDivergenceMetric): + """KL-Divergence metric based on pre-trained PASST classifier on AudioSet. + + From: PaSST: Efficient Training of Audio Transformers with Patchout + Paper: https://arxiv.org/abs/2110.05069 + Implementation: https://github.com/kkoutini/PaSST + + Follow instructions from the github repo: + ``` + pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt' + ``` + + Args: + pretrained_length (float, optional): Audio duration used for the pretrained model. + """ + def __init__(self, pretrained_length: tp.Optional[float] = None): + super().__init__() + self._initialize_model(pretrained_length) + + def _initialize_model(self, pretrained_length: tp.Optional[float] = None): + """Initialize underlying PaSST audio classifier.""" + model, sr, max_frames, min_frames = self._load_base_model(pretrained_length) + self.min_input_frames = min_frames + self.max_input_frames = max_frames + self.model_sample_rate = sr + self.model = model + self.model.eval() + self.model.to(self.device) + + def _load_base_model(self, pretrained_length: tp.Optional[float]): + """Load pretrained model from PaSST.""" + try: + if pretrained_length == 30: + from hear21passt.base30sec import get_basic_model # type: ignore + max_duration = 30 + elif pretrained_length == 20: + from hear21passt.base20sec import get_basic_model # type: ignore + max_duration = 20 + else: + from hear21passt.base import get_basic_model # type: ignore + # Original PASST was trained on AudioSet with 10s-long audio samples + max_duration = 10 + min_duration = 0.15 + min_duration = 0.15 + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Please install hear21passt to compute KL divergence: ", + "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'" + ) + model_sample_rate = 32_000 + max_input_frames = int(max_duration * model_sample_rate) + min_input_frames = int(min_duration * model_sample_rate) + with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): + model = get_basic_model(mode='logits') + return model, model_sample_rate, max_input_frames, min_input_frames + + def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]: + """Process audio to feed to the pretrained model.""" + wav = wav.unsqueeze(0) + wav = wav[..., :wav_len] + wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1) + wav = wav.squeeze(0) + # we don't pad but return a list of audio segments as this otherwise affects the KLD computation + segments = torch.split(wav, self.max_input_frames, dim=-1) + valid_segments = [] + for s in segments: + # ignoring too small segments that are breaking the model inference + if s.size(-1) > self.min_input_frames: + valid_segments.append(s) + return [s[None] for s in valid_segments] + + def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor: + """Run the pretrained model and get the predictions.""" + assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}" + wav = wav.mean(dim=1) + # PaSST is printing a lot of garbage that we are not interested in + with open(os.devnull, "w") as f, contextlib.redirect_stdout(f): + with torch.no_grad(), _patch_passt_stft(): + logits = self.model(wav.to(self.device)) + probs = torch.softmax(logits, dim=-1) + return probs + + def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, + sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: + """Get model output given provided input tensor. + + Args: + x (torch.Tensor): Input audio tensor of shape [B, C, T]. + sizes (torch.Tensor): Actual audio sample length, of shape [B]. + sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. + Returns: + probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes]. + """ + all_probs: tp.List[torch.Tensor] = [] + for i, wav in enumerate(x): + sample_rate = int(sample_rates[i].item()) + wav_len = int(sizes[i].item()) + wav_segments = self._process_audio(wav, sample_rate, wav_len) + for segment in wav_segments: + probs = self._get_model_preds(segment).mean(dim=0) + all_probs.append(probs) + if len(all_probs) > 0: + return torch.stack(all_probs, dim=0) + else: + return None diff --git a/audiocraft/metrics/rvm.py b/audiocraft/metrics/rvm.py new file mode 100644 index 0000000000000000000000000000000000000000..2047b6c8d5b1d58a67090b947e7e2666c3104eca --- /dev/null +++ b/audiocraft/metrics/rvm.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp +import torch +from torch import nn +import torchaudio + + +def db_to_scale(volume: tp.Union[float, torch.Tensor]): + return 10 ** (volume / 20) + + +def scale_to_db(scale: torch.Tensor, min_volume: float = -120): + min_scale = db_to_scale(min_volume) + return 20 * torch.log10(scale.clamp(min=min_scale)) + + +class RelativeVolumeMel(nn.Module): + """Relative volume melspectrogram measure. + + Computes a measure of distance over two mel spectrogram that is interpretable in terms + of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will + first renormalize both by the ground truth of `x_ref`. + + ..Warning:: This class returns the volume of the distortion at the spectrogram level, + e.g. low negative values reflects lower distortion levels. For a SNR (like reported + in the MultiBandDiffusion paper), just take `-rvm`. + + Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference + relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g. + clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`) + with the goal of avoiding the loss being dominated by parts where the reference is almost silent. + Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final + average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely + good (for a neural network output, although sound engineers typically aim for much lower attenuations). + Similarly, anything above +30 dB would just be completely missing the target, and there is no point + in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more + in line with what neural nets currently can achieve. + + For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between + the target and reference mel-spec is 10 dB lower than the reference mel-spec value. + + The metric can be aggregated over a given frequency band in order have different insights for + different region of the spectrum. `num_aggregated_bands` controls the number of bands. + + ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it + is numerically stable when computing its gradient. We thus advise against using it as a training loss. + + Args: + sample_rate (int): Sample rate of the input audio. + n_mels (int): Number of mel bands to use. + n_fft (int): Number of frequency bins for the STFT. + hop_length (int): Hop length of the STFT and the mel-spectrogram. + min_relative_volume (float): The error `z_ref - z_est` volume is given relative to + the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped. + max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that. + max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain + to that amount, to avoid rescaling near silence. Given in dB. + min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume + bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram, + and anything below that will be considered equally. + num_aggregated_bands (int): Number of bands to keep when computing the average RVM value. + For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs. + """ + def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512, + hop_length: int = 128, min_relative_volume: float = -25, + max_relative_volume: float = 25, max_initial_gain: float = 25, + min_activity_volume: float = -25, + num_aggregated_bands: int = 4) -> None: + super().__init__() + self.melspec = torchaudio.transforms.MelSpectrogram( + n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, + normalized=True, sample_rate=sample_rate, power=2) + self.min_relative_volume = min_relative_volume + self.max_relative_volume = max_relative_volume + self.max_initial_gain = max_initial_gain + self.min_activity_volume = min_activity_volume + self.num_aggregated_bands = num_aggregated_bands + + def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]: + """Compute RVM metric between estimate and reference samples. + + Args: + estimate (torch.Tensor): Estimate sample. + ground_truth (torch.Tensor): Reference sample. + + Returns: + dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}` + for the RVM over the k-th band (k=0..num_aggregated_bands - 1). + """ + min_scale = db_to_scale(-self.max_initial_gain) + std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale) + z_gt = self.melspec(ground_truth / std).sqrt() + z_est = self.melspec(estimate / std).sqrt() + + delta = z_gt - z_est + ref_db = scale_to_db(z_gt, self.min_activity_volume) + delta_db = scale_to_db(delta.abs(), min_volume=-120) + relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume) + dims = list(range(relative_db.dim())) + dims.remove(dims[-2]) + losses_per_band = relative_db.mean(dim=dims) + aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)] + metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)} + metrics['rvm'] = losses_per_band.mean() + return metrics diff --git a/audiocraft/metrics/visqol.py b/audiocraft/metrics/visqol.py new file mode 100644 index 0000000000000000000000000000000000000000..44f4b0a2c3c6c726857db8386491823dd85dde51 --- /dev/null +++ b/audiocraft/metrics/visqol.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import csv +import json +import logging +from pathlib import Path +import tempfile +import typing as tp +import subprocess +import shutil + +import torch +import torchaudio + +logger = logging.getLogger(__name__) + + +class ViSQOL: + """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary. + + To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the + instructions available in the open source repository: https://github.com/google/visqol + + ViSQOL is capable of running in two modes: + + Audio Mode: + When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz. + Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. + Audio mode uses support vector regression, with the maximum range at ~4.75. + + Speech Mode: + When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz. + Input should be resampled to 16kHz. + As part of the speech mode processing, a root mean square implementation for voice activity detection + is performed on the reference signal to determine what parts of the signal have voice activity and + should therefore be included in the comparison. The signal is normalized before performing the voice + activity detection. + Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. + Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior. + + For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input + + Args: + visqol_bin (str): Path to the ViSQOL binary. + mode (str): ViSQOL computation mode, expecting "audio" or "speech". + model (str): Name of the model to use for similarity to quality model. + debug (bool): Whether to also get debug metrics from ViSQOL or not. + """ + SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000} + ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values()) + + def __init__(self, bin: tp.Union[Path, str], mode: str = "audio", + model: str = "libsvm_nu_svr_model.txt", debug: bool = False): + assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}" + self.visqol_bin = str(bin) + self.visqol_mode = mode + self.target_sr = self._get_target_sr(self.visqol_mode) + self.model = model + self.debug = debug + assert Path(self.visqol_model).exists(), \ + f"Could not find the specified model in ViSQOL install: {self.visqol_model}" + + def _get_target_sr(self, mode: str) -> int: + # returns target sampling rate for the corresponding ViSQOL mode. + if mode not in ViSQOL.SAMPLE_RATES_MODES: + raise ValueError( + f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}" + ) + return ViSQOL.SAMPLE_RATES_MODES[mode] + + def _prepare_files( + self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False + ): + # prepare files for ViSQOL evaluation. + assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES + assert len(ref_sig) == len(deg_sig), ( + "Expects same number of ref and degraded inputs", + f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}" + ) + # resample audio if needed + if sr != target_sr: + transform = torchaudio.transforms.Resample(sr, target_sr) + pad = int(0.5 * target_sr) + rs_ref = [] + rs_deg = [] + for i in range(len(ref_sig)): + rs_ref_i = transform(ref_sig[i]) + rs_deg_i = transform(deg_sig[i]) + if pad_with_silence: + rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0) + rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0) + rs_ref.append(rs_ref_i) + rs_deg.append(rs_deg_i) + ref_sig = torch.stack(rs_ref) + deg_sig = torch.stack(rs_deg) + # save audio chunks to tmp dir and create csv + tmp_dir = Path(tempfile.mkdtemp()) + try: + tmp_input_csv_path = tmp_dir / "input.csv" + tmp_results_csv_path = tmp_dir / "results.csv" + tmp_debug_json_path = tmp_dir / "debug.json" + with open(tmp_input_csv_path, "w") as csv_file: + csv_writer = csv.writer(csv_file) + csv_writer.writerow(["reference", "degraded"]) + for i in range(len(ref_sig)): + tmp_ref_filename = tmp_dir / f"ref_{i}.wav" + tmp_deg_filename = tmp_dir / f"deg_{i}.wav" + torchaudio.save( + tmp_ref_filename, + torch.clamp(ref_sig[i], min=-0.99, max=0.99), + sample_rate=target_sr, + bits_per_sample=16, + encoding="PCM_S" + ) + torchaudio.save( + tmp_deg_filename, + torch.clamp(deg_sig[i], min=-0.99, max=0.99), + sample_rate=target_sr, + bits_per_sample=16, + encoding="PCM_S" + ) + csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)]) + return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path + except Exception as e: + logger.error("Exception occurred when preparing files for ViSQOL: %s", e) + return tmp_dir, None, None, None + + def _flush_files(self, tmp_dir: tp.Union[Path, str]): + # flush tmp files used to compute ViSQOL. + shutil.rmtree(str(tmp_dir)) + + def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float: + # collect results for each evaluated pair and return averaged moslqo score. + with open(results_csv_path, "r") as csv_file: + reader = csv.DictReader(csv_file) + moslqo_scores = [float(row["moslqo"]) for row in reader] + if len(moslqo_scores) > 0: + return sum(moslqo_scores) / len(moslqo_scores) + else: + return 0.0 + + def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict: + # collect debug data for the visqol inference. + with open(debug_json_path, "r") as f: + data = json.load(f) + return data + + @property + def visqol_model(self): + return f'{self.visqol_bin}/model/{self.model}' + + def _run_visqol( + self, + input_csv_path: tp.Union[Path, str], + results_csv_path: tp.Union[Path, str], + debug_csv_path: tp.Optional[tp.Union[Path, str]], + ): + input_csv_path = str(input_csv_path) + results_csv_path = str(results_csv_path) + debug_csv_path = str(debug_csv_path) + cmd = [ + f'{self.visqol_bin}/bazel-bin/visqol', + '--batch_input_csv', f'{input_csv_path}', + '--results_csv', f'{results_csv_path}' + ] + if debug_csv_path is not None: + cmd += ['--output_debug', f'{debug_csv_path}'] + if self.visqol_mode == "speech": + cmd += ['--use_speech_mode'] + cmd += ['--similarity_to_quality_model', f'{self.visqol_model}'] + result = subprocess.run(cmd, capture_output=True) + if result.returncode: + logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode()) + raise RuntimeError("Error while executing visqol") + result.check_returncode() + + def __call__( + self, + ref_sig: torch.Tensor, + deg_sig: torch.Tensor, + sr: int, + pad_with_silence: bool = False, + ): + """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate. + Args: + ref_sig (torch.Tensor): Reference signals as [B, C, T]. + deg_sig (torch.Tensor): Degraded signals as [B, C, T]. + sr (int): Sample rate of the two audio signals. + pad_with_silence (bool): Whether to pad the file with silences as recommended + in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input). + Returns: + float: The ViSQOL score or mean score for the batch. + """ + logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples") + tmp_dir, input_csv, results_csv, debug_json = self._prepare_files( + ref_sig, deg_sig, sr, self.target_sr, pad_with_silence + ) + try: + if input_csv and results_csv: + self._run_visqol( + input_csv, + results_csv, + debug_json if self.debug else None, + ) + mosqol = self._collect_moslqo_score(results_csv) + return mosqol + else: + raise RuntimeError("Something unexpected happened when running VISQOL!") + except Exception as e: + logger.error("Exception occurred when running ViSQOL: %s", e) + finally: + self._flush_files(tmp_dir) diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py index 92c7a48a200eba455044cd66e0d2c1efe6494f5c..be6bfe4b787a132aeaabaed1c3437c9ecd5c656c 100644 --- a/audiocraft/models/__init__.py +++ b/audiocraft/models/__init__.py @@ -3,8 +3,16 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - +""" +Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. +""" # flake8: noqa -from .musicgen import MusicGen +from . import builders, loaders +from .encodec import ( + CompressionModel, EncodecModel, DAC, + HFEncodecModel, HFEncodecCompressionModel) +from .audiogen import AudioGen from .lm import LMModel -from .encodec import CompressionModel, EncodecModel +from .multibanddiffusion import MultiBandDiffusion +from .musicgen import MusicGen +from .unet import DiffusionUnet diff --git a/audiocraft/models/audiogen.py b/audiocraft/models/audiogen.py new file mode 100644 index 0000000000000000000000000000000000000000..b4df536eddd3657cc8cc6bf846606a16df8985eb --- /dev/null +++ b/audiocraft/models/audiogen.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Main model for using AudioGen. This will combine all the required components +and provide easy access to the generation API. +""" + +import typing as tp + +import torch + +from .encodec import CompressionModel +from .lm import LMModel +from .builders import get_debug_compression_model, get_debug_lm_model +from .loaders import load_compression_model, load_lm_model +from ..data.audio_utils import convert_audio +from ..modules.conditioners import ConditioningAttributes +from ..utils.autocast import TorchAutocast + + +class AudioGen: + """AudioGen main model with convenient generation API. + + Args: + name (str): name of the model. + compression_model (CompressionModel): Compression model + used to map audio to invertible discrete representations. + lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. + """ + def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, + max_duration: tp.Optional[float] = None): + self.name = name + self.compression_model = compression_model + self.lm = lm + # Just to be safe, let's put everything in eval mode. + self.compression_model.eval() + self.lm.eval() + + if max_duration is None: + if hasattr(lm, 'cfg'): + max_duration = lm.cfg.dataset.segment_duration # type: ignore + else: + raise ValueError("You must provide max_duration when building directly AudioGen") + assert max_duration is not None + self.max_duration: float = max_duration + self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} + self.set_generation_params(duration=5) # 5 seconds by default + self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None + if self.device.type == 'cpu': + self.autocast = TorchAutocast(enabled=False) + else: + self.autocast = TorchAutocast( + enabled=True, device_type=self.device.type, dtype=torch.float16) + + @property + def frame_rate(self) -> float: + """Roughly the number of AR steps per seconds.""" + return self.compression_model.frame_rate + + @property + def sample_rate(self) -> int: + """Sample rate of the generated audio.""" + return self.compression_model.sample_rate + + @property + def audio_channels(self) -> int: + """Audio channels of the generated audio.""" + return self.compression_model.channels + + @staticmethod + def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): + """Return pretrained model, we provide a single model for now: + - facebook/audiogen-medium (1.5B), text to sound, + # see: https://huggingface.co./facebook/audiogen-medium + """ + if device is None: + if torch.cuda.device_count(): + device = 'cuda' + else: + device = 'cpu' + + if name == 'debug': + # used only for unit tests + compression_model = get_debug_compression_model(device, sample_rate=16000) + lm = get_debug_lm_model(device) + return AudioGen(name, compression_model, lm, max_duration=10) + + compression_model = load_compression_model(name, device=device) + lm = load_lm_model(name, device=device) + assert 'self_wav' not in lm.condition_provider.conditioners, \ + "AudioGen do not support waveform conditioning for now" + return AudioGen(name, compression_model, lm) + + def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, + top_p: float = 0.0, temperature: float = 1.0, + duration: float = 10.0, cfg_coef: float = 3.0, + two_step_cfg: bool = False, extend_stride: float = 2): + """Set the generation parameters for AudioGen. + + Args: + use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. + top_k (int, optional): top_k used for sampling. Defaults to 250. + top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. + temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. + duration (float, optional): Duration of the generated waveform. Defaults to 10.0. + cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. + two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, + instead of batching together the two. This has some impact on how things + are padded but seems to have little impact in practice. + extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much + should we extend the audio each time. Larger values will mean less context is + preserved, and shorter value will require extra computations. + """ + assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." + self.extend_stride = extend_stride + self.duration = duration + self.generation_params = { + 'use_sampling': use_sampling, + 'temp': temperature, + 'top_k': top_k, + 'top_p': top_p, + 'cfg_coef': cfg_coef, + 'two_step_cfg': two_step_cfg, + } + + def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): + """Override the default progress callback.""" + self._progress_callback = progress_callback + + def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor: + """Generate samples conditioned on text. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + assert prompt_tokens is None + return self._generate_tokens(attributes, prompt_tokens, progress) + + def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, + descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, + progress: bool = False) -> torch.Tensor: + """Generate samples conditioned on audio prompts. + + Args: + prompt (torch.Tensor): A batch of waveforms used for continuation. + Prompt should be [B, C, T], or [C, T] if only one sample is generated. + prompt_sample_rate (int): Sampling rate of the given audio waveforms. + descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + if prompt.dim() == 2: + prompt = prompt[None] + if prompt.dim() != 3: + raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") + prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) + if descriptions is None: + descriptions = [None] * len(prompt) + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) + assert prompt_tokens is not None + return self._generate_tokens(attributes, prompt_tokens, progress) + + @torch.no_grad() + def _prepare_tokens_and_attributes( + self, + descriptions: tp.Sequence[tp.Optional[str]], + prompt: tp.Optional[torch.Tensor], + ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: + """Prepare model inputs. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + prompt (torch.Tensor): A batch of waveforms used for continuation. + """ + attributes = [ + ConditioningAttributes(text={'description': description}) + for description in descriptions] + + if prompt is not None: + if descriptions is not None: + assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match" + prompt = prompt.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt) + assert scale is None + else: + prompt_tokens = None + return attributes, prompt_tokens + + def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], + prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: + """Generate discrete audio tokens given audio prompt and/or conditions. + + Args: + attributes (list of ConditioningAttributes): Conditions used for generation (here text). + prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + Returns: + torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. + """ + total_gen_len = int(self.duration * self.frame_rate) + max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) + current_gen_offset: int = 0 + + def _progress_callback(generated_tokens: int, tokens_to_generate: int): + generated_tokens += current_gen_offset + if self._progress_callback is not None: + # Note that total_gen_len might be quite wrong depending on the + # codebook pattern used, but with delay it is almost accurate. + self._progress_callback(generated_tokens, total_gen_len) + else: + print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') + + if prompt_tokens is not None: + assert max_prompt_len >= prompt_tokens.shape[-1], \ + "Prompt is longer than audio to generate" + + callback = None + if progress: + callback = _progress_callback + + if self.duration <= self.max_duration: + # generate by sampling from LM, simple case. + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=total_gen_len, **self.generation_params) + + else: + all_tokens = [] + if prompt_tokens is None: + prompt_length = 0 + else: + all_tokens.append(prompt_tokens) + prompt_length = prompt_tokens.shape[-1] + + stride_tokens = int(self.frame_rate * self.extend_stride) + while current_gen_offset + prompt_length < total_gen_len: + time_offset = current_gen_offset / self.frame_rate + chunk_duration = min(self.duration - time_offset, self.max_duration) + max_gen_len = int(chunk_duration * self.frame_rate) + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=max_gen_len, **self.generation_params) + if prompt_tokens is None: + all_tokens.append(gen_tokens) + else: + all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) + prompt_tokens = gen_tokens[:, :, stride_tokens:] + prompt_length = prompt_tokens.shape[-1] + current_gen_offset += stride_tokens + + gen_tokens = torch.cat(all_tokens, dim=-1) + + # generate audio + assert gen_tokens.dim() == 3 + with torch.no_grad(): + gen_audio = self.compression_model.decode(gen_tokens, None) + return gen_audio diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py index 77ee5f96fea2e3c9e475fe961bc1a5ee473ed8eb..b7144874457e569d6e25fe30cafa0cddc1dd59a1 100644 --- a/audiocraft/models/builders.py +++ b/audiocraft/models/builders.py @@ -10,32 +10,34 @@ from the Hydra config. """ import typing as tp -import warnings import audiocraft import omegaconf import torch -from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa +from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel from .lm import LMModel from ..modules.codebooks_patterns import ( CodebooksPatternProvider, DelayedPatternProvider, + MusicLMPattern, ParallelPatternProvider, UnrolledPatternProvider, - VALLEPattern, - MusicLMPattern, + CoarseFirstPattern, ) from ..modules.conditioners import ( BaseConditioner, + ChromaStemConditioner, + CLAPEmbeddingConditioner, + ConditionFuser, ConditioningProvider, LUTConditioner, T5Conditioner, - ConditionFuser, - ChromaStemConditioner, ) +from .unet import DiffusionUnet from .. import quantization as qt from ..utils.utils import dict_from_config +from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: @@ -60,12 +62,11 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs) return encoder, decoder else: - raise KeyError(f'Unexpected compression model {cfg.compression_model}') + raise KeyError(f"Unexpected compression model {cfg.compression_model}") def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: - """Instantiate a compression model. - """ + """Instantiate a compression model.""" if cfg.compression_model == 'encodec': kwargs = dict_from_config(getattr(cfg, 'encodec')) encoder_name = kwargs.pop('autoencoder') @@ -73,20 +74,17 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', None) - renorm = kwargs.pop('renorm') - if renormalize is None: - renormalize = renorm is not None - warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.") + renormalize = kwargs.pop('renormalize', False) + # deprecated params + kwargs.pop('renorm', None) return EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) else: - raise KeyError(f'Unexpected compression model {cfg.compression_model}') + raise KeyError(f"Unexpected compression model {cfg.compression_model}") def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: - """Instantiate a transformer LM. - """ + """Instantiate a transformer LM.""" if cfg.lm_model == 'transformer_lm': kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) n_q = kwargs['n_q'] @@ -94,14 +92,14 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) - cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"] + cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef'] fuser = get_condition_fuser(cfg) condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) - if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programatically + if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically kwargs['cross_attention'] = True if codebooks_pattern_cfg.modeling is None: assert q_modeling is not None, \ - 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling' + "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" codebooks_pattern_cfg = omegaconf.OmegaConf.create( {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} ) @@ -118,45 +116,50 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: **kwargs ).to(cfg.device) else: - raise KeyError(f'Unexpected LM model {cfg.lm_model}') + raise KeyError(f"Unexpected LM model {cfg.lm_model}") def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: - """Instantiate a conditioning model. - """ + """Instantiate a conditioning model.""" device = cfg.device duration = cfg.dataset.segment_duration - cfg = getattr(cfg, "conditioners") - cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg + cfg = getattr(cfg, 'conditioners') + dict_cfg = {} if cfg is None else dict_from_config(cfg) conditioners: tp.Dict[str, BaseConditioner] = {} - with omegaconf.open_dict(cfg): - condition_provider_args = cfg.pop('args', {}) - for cond, cond_cfg in cfg.items(): - model_type = cond_cfg["model"] + condition_provider_args = dict_cfg.pop('args', {}) + condition_provider_args.pop('merge_text_conditions_p', None) + condition_provider_args.pop('drop_desc_p', None) + + for cond, cond_cfg in dict_cfg.items(): + model_type = cond_cfg['model'] model_args = cond_cfg[model_type] - if model_type == "t5": + if model_type == 't5': conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) - elif model_type == "lut": + elif model_type == 'lut': conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) - elif model_type == "chroma_stem": - model_args.pop('cache_path', None) + elif model_type == 'chroma_stem': conditioners[str(cond)] = ChromaStemConditioner( output_dim=output_dim, duration=duration, device=device, **model_args ) + elif model_type == 'clap': + conditioners[str(cond)] = CLAPEmbeddingConditioner( + output_dim=output_dim, + device=device, + **model_args + ) else: - raise ValueError(f"unrecognized conditioning model: {model_type}") + raise ValueError(f"Unrecognized conditioning model: {model_type}") conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) return conditioner def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: - """Instantiate a condition fuser object. - """ - fuser_cfg = getattr(cfg, "fuser") - fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] + """Instantiate a condition fuser object.""" + fuser_cfg = getattr(cfg, 'fuser') + fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate'] fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) @@ -164,13 +167,12 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: - """Instantiate a codebooks pattern provider object. - """ + """Instantiate a codebooks pattern provider object.""" pattern_providers = { 'parallel': ParallelPatternProvider, 'delay': DelayedPatternProvider, 'unroll': UnrolledPatternProvider, - 'valle': VALLEPattern, + 'coarse_first': CoarseFirstPattern, 'musiclm': MusicLMPattern, } name = cfg.modeling @@ -179,14 +181,20 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb return klass(n_q, **kwargs) -def get_debug_compression_model(device='cpu'): - """Instantiate a debug compression model to be used for unit tests. - """ - seanet_kwargs = { +def get_debug_compression_model(device='cpu', sample_rate: int = 32000): + """Instantiate a debug compression model to be used for unit tests.""" + assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model" + model_ratios = { + 16000: [10, 8, 8], # 25 Hz at 16kHz + 32000: [10, 8, 16] # 25 Hz at 32kHz + } + ratios: tp.List[int] = model_ratios[sample_rate] + frame_rate = 25 + seanet_kwargs: dict = { 'n_filters': 4, 'n_residual_layers': 1, 'dimension': 32, - 'ratios': [10, 8, 16] # 25 Hz at 32kHz + 'ratios': ratios, } encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) @@ -195,13 +203,31 @@ def get_debug_compression_model(device='cpu'): quantizer(init_x, 1) # initialize kmeans etc. compression_model = EncodecModel( encoder, decoder, quantizer, - frame_rate=25, sample_rate=32000, channels=1).to(device) + frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device) return compression_model.eval() +def get_diffusion_model(cfg: omegaconf.DictConfig): + # TODO Find a way to infer the channels from dset + channels = cfg.channels + num_steps = cfg.schedule.num_steps + return DiffusionUnet( + chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + + +def get_processor(cfg, sample_rate: int = 24000): + sample_processor = SampleProcessor() + if cfg.use: + kw = dict(cfg) + kw.pop('use') + kw.pop('name') + if cfg.name == "multi_band_processor": + sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw) + return sample_processor + + def get_debug_lm_model(device='cpu'): - """Instantiate a debug LM to be used for unit tests. - """ + """Instantiate a debug LM to be used for unit tests.""" pattern = DelayedPatternProvider(n_q=4) dim = 16 providers = { @@ -216,3 +242,17 @@ def get_debug_lm_model(device='cpu'): n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, cross_attention=True, causal=True) return lm.to(device).eval() + + +def get_wrapped_compression_model( + compression_model: CompressionModel, + cfg: omegaconf.DictConfig) -> CompressionModel: + if hasattr(cfg, 'interleave_stereo_codebooks'): + if cfg.interleave_stereo_codebooks.use: + kwargs = dict_from_config(cfg.interleave_stereo_codebooks) + kwargs.pop('use') + compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs) + if hasattr(cfg, 'compression_model_n_q'): + if cfg.compression_model_n_q is not None: + compression_model.set_num_codebooks(cfg.compression_model_n_q) + return compression_model diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py index 69621a695887b0b41614c51cae020f6fd0af221d..d4e77a941ef6b45ca54933afc6e430a75390013c 100644 --- a/audiocraft/models/encodec.py +++ b/audiocraft/models/encodec.py @@ -3,18 +3,32 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Compression models or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer. +""" from abc import ABC, abstractmethod +import logging +import math +from pathlib import Path import typing as tp from einops import rearrange +import numpy as np import torch from torch import nn +from transformers import EncodecModel as HFEncodecModel from .. import quantization as qt +logger = logging.getLogger() + + class CompressionModel(ABC, nn.Module): + """Base API for all compression model that aim at being used as audio tokenizers + with a language model. + """ @abstractmethod def forward(self, x: torch.Tensor) -> qt.QuantizedResult: @@ -22,12 +36,17 @@ class CompressionModel(ABC, nn.Module): @abstractmethod def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - """See `EncodecModel.encode`""" + """See `EncodecModel.encode`.""" ... @abstractmethod def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): - """See `EncodecModel.decode`""" + """See `EncodecModel.decode`.""" + ... + + @abstractmethod + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" ... @property @@ -37,7 +56,7 @@ class CompressionModel(ABC, nn.Module): @property @abstractmethod - def frame_rate(self) -> int: + def frame_rate(self) -> float: ... @property @@ -62,10 +81,46 @@ class CompressionModel(ABC, nn.Module): @abstractmethod def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer. - """ + """Set the active number of codebooks used by the quantizer.""" ... + @staticmethod + def get_pretrained( + name: str, device: tp.Union[torch.device, str] = 'cpu' + ) -> 'CompressionModel': + """Instantiate a CompressionModel from a given pretrained model. + + Args: + name (Path or str): name of the pretrained model. See after. + device (torch.device or str): Device on which the model is loaded. + + Pretrained models: + - dac_44khz (https://github.com/descriptinc/descript-audio-codec) + - dac_24khz (same) + - facebook/encodec_24khz (https://huggingface.co./facebook/encodec_24khz) + - facebook/encodec_32khz (https://huggingface.co./facebook/encodec_32khz) + - your own model on HugginFace. Export instructions to come... + """ + + from . import builders, loaders + model: CompressionModel + if name in ['dac_44khz', 'dac_24khz']: + model_type = name.split('_')[1] + logger.info("Getting pretrained compression model from DAC %s", model_type) + model = DAC(model_type) + elif name in ['debug_compression_model']: + logger.info("Getting pretrained compression model for debug") + model = builders.get_debug_compression_model() + elif Path(name).exists(): + # We assume here if the paths exist that it is in fact an AC checkpoint + # that was exported using `audiocraft.utils.export` functions. + model = loaders.load_compression_model(name, device=device) + else: + logger.info("Getting pretrained compression model from HF %s", name) + hf_model = HFEncodecModel.from_pretrained(name) + model = HFEncodecCompressionModel(hf_model).to(device) + return model.to(device).eval() + class EncodecModel(CompressionModel): """Encodec model operating on the raw waveform. @@ -80,9 +135,9 @@ class EncodecModel(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ - # we need assignement to override the property in the abstract class, + # we need assignment to override the property in the abstract class, # I couldn't find a better way... - frame_rate: int = 0 + frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 @@ -111,25 +166,21 @@ class EncodecModel(CompressionModel): @property def total_codebooks(self): - """Total number of quantizer codebooks available. - """ + """Total number of quantizer codebooks available.""" return self.quantizer.total_codebooks @property def num_codebooks(self): - """Active number of codebooks used by the quantizer. - """ + """Active number of codebooks used by the quantizer.""" return self.quantizer.num_codebooks def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer. - """ + """Set the active number of codebooks used by the quantizer.""" self.quantizer.set_num_codebooks(n) @property def cardinality(self): - """Cardinality of each codebook. - """ + """Cardinality of each codebook.""" return self.quantizer.bins def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: @@ -176,7 +227,7 @@ class EncodecModel(CompressionModel): x (torch.Tensor): Float tensor of shape [B, C, T] Returns: - codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: + codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. scale a float tensor containing the scale for audio renormalizealization. """ @@ -192,41 +243,174 @@ class EncodecModel(CompressionModel): Args: codes (torch.Tensor): Int tensor of shape [B, K, T] - scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value. + scale (torch.Tensor, optional): Float tensor containing the scale value. Returns: out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. """ - emb = self.quantizer.decode(codes) + emb = self.decode_latent(codes) out = self.decoder(emb) out = self.postprocess(out, scale) # out contains extra padding added by the encoder and decoder return out + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.quantizer.decode(codes) + + +class DAC(CompressionModel): + def __init__(self, model_type: str = "44khz"): + super().__init__() + try: + import dac.utils + except ImportError: + raise RuntimeError("Could not import dac, make sure it is installed, " + "please run `pip install descript-audio-codec`") + self.model = dac.utils.load_model(model_type=model_type) + self.n_quantizers = self.total_codebooks + self.model.eval() + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + codes = self.model.encode(x, self.n_quantizers)[1] + return codes[:, :self.n_quantizers], None + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + assert scale is None + z_q = self.decode_latent(codes) + return self.model.decode(z_q) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.quantizer.from_codes(codes)[0] + + @property + def channels(self) -> int: + return 1 + + @property + def frame_rate(self) -> float: + return self.model.sample_rate / self.model.hop_length + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def cardinality(self) -> int: + return self.model.codebook_size + + @property + def num_codebooks(self) -> int: + return self.n_quantizers -class FlattenedCompressionModel(CompressionModel): - """Wraps a CompressionModel and flatten its codebooks, e.g. - instead of returning [B, K, T], return [B, S, T * (K // S)] with - S the number of codebooks per step, and `K // S` the number of 'virtual steps' - for each real time step. + @property + def total_codebooks(self) -> int: + return self.model.n_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + + +class HFEncodecCompressionModel(CompressionModel): + """Wrapper around HuggingFace Encodec. + """ + def __init__(self, model: HFEncodecModel): + super().__init__() + self.model = model + bws = self.model.config.target_bandwidths + num_codebooks = [ + bw * 1000 / (self.frame_rate * math.log2(self.cardinality)) + for bw in bws + ] + deltas = [nc - int(nc) for nc in num_codebooks] + # Checking we didn't do some bad maths and we indeed have integers! + assert all(deltas) <= 1e-3, deltas + self.possible_num_codebooks = [int(nc) for nc in num_codebooks] + self.set_num_codebooks(max(self.possible_num_codebooks)) + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + # We don't support training with this. + raise NotImplementedError("Forward and training with HF EncodecModel not supported.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks) + bandwidth = self.model.config.target_bandwidths[bandwidth_index] + res = self.model.encode(x, None, bandwidth) + assert len(res[0]) == 1 + assert len(res[1]) == 1 + return res[0][0], res[1][0] + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + if scale is None: + scales = [None] # type: ignore + else: + scales = scale # type: ignore + res = self.model.decode(codes[None], scales) + return res[0] + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.quantizer.decode(codes.transpose(0, 1)) + + @property + def channels(self) -> int: + return self.model.config.audio_channels + + @property + def frame_rate(self) -> float: + hop_length = int(np.prod(self.model.config.upsampling_ratios)) + return self.sample_rate / hop_length + + @property + def sample_rate(self) -> int: + return self.model.config.sampling_rate + + @property + def cardinality(self) -> int: + return self.model.config.codebook_size + + @property + def num_codebooks(self) -> int: + return self._num_codebooks + + @property + def total_codebooks(self) -> int: + return max(self.possible_num_codebooks) + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + if n not in self.possible_num_codebooks: + raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}") + self._num_codebooks = n + + +class InterleaveStereoCompressionModel(CompressionModel): + """Wraps a CompressionModel to support stereo inputs. The wrapped model + will be applied independently to the left and right channels, and both codebooks + will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per + channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on + `per_timestep`. Args: - model (CompressionModel): compression model to wrap. - codebooks_per_step (int): number of codebooks to keep per step, - this must divide the number of codebooks provided by the wrapped model. - extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1, - if each codebook has a cardinality N, then the first codebook will - use the range [0, N - 1], and the second [N, 2 N - 1] etc. - On decoding, this can lead to potentially invalid sequences. - Any invalid entry will be silently remapped to the proper range - with a modulo. + model (CompressionModel): Compression model to wrap. + per_timestep (bool): Whether to interleave on the timestep dimension + or on the codebooks dimension. """ - def __init__(self, model: CompressionModel, codebooks_per_step: int = 1, - extend_cardinality: bool = True): + def __init__(self, model: CompressionModel, per_timestep: bool = False): super().__init__() self.model = model - self.codebooks_per_step = codebooks_per_step - self.extend_cardinality = extend_cardinality + self.per_timestep = per_timestep + assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio" @property def total_codebooks(self): @@ -236,30 +420,27 @@ class FlattenedCompressionModel(CompressionModel): def num_codebooks(self): """Active number of codebooks used by the quantizer. - ..Warning:: this reports the number of codebooks after the flattening + ..Warning:: this reports the number of codebooks after the interleaving of the codebooks! """ - assert self.model.num_codebooks % self.codebooks_per_step == 0 - return self.codebooks_per_step + return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2 def set_num_codebooks(self, n: int): """Set the active number of codebooks used by the quantizer. - ..Warning:: this sets the number of codebooks **before** the flattening - of the codebooks. + ..Warning:: this sets the number of codebooks before the interleaving! """ - assert n % self.codebooks_per_step == 0 self.model.set_num_codebooks(n) @property - def num_virtual_steps(self) -> int: + def num_virtual_steps(self) -> float: """Return the number of virtual steps, e.g. one real step will be split into that many steps. """ - return self.model.num_codebooks // self.codebooks_per_step + return 2 if self.per_timestep else 1 @property - def frame_rate(self) -> int: + def frame_rate(self) -> float: return self.model.frame_rate * self.num_virtual_steps @property @@ -268,35 +449,58 @@ class FlattenedCompressionModel(CompressionModel): @property def channels(self) -> int: - return self.model.channels + return 2 @property def cardinality(self): """Cardinality of each codebook. """ - if self.extend_cardinality: - return self.model.cardinality * self.num_virtual_steps - else: - return self.model.cardinality + return self.model.cardinality def forward(self, x: torch.Tensor) -> qt.QuantizedResult: raise NotImplementedError("Not supported, use encode and decode.") def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - indices, scales = self.model.encode(x) - B, K, T = indices.shape - indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step) - if self.extend_cardinality: - for virtual_step in range(1, self.num_virtual_steps): - indices[..., virtual_step] += self.model.cardinality * virtual_step - indices = rearrange(indices, 'b k t v -> b k (t v)') + B, C, T = x.shape + assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}" + + indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1)) + indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1)) + indices = torch.stack([indices_c0, indices_c1], dim=0) + scales: tp.Optional[torch.Tensor] = None + if scales_c0 is not None and scales_c1 is not None: + scales = torch.stack([scales_c0, scales_c1], dim=1) + + if self.per_timestep: + indices = rearrange(indices, 'c b k t -> b k (t c)', c=2) + else: + indices = rearrange(indices, 'c b k t -> b (k c) t', c=2) + return (indices, scales) + def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + if self.per_timestep: + codes = rearrange(codes, 'b k (t c) -> c b k t', c=2) + else: + codes = rearrange(codes, 'b (k c) t -> c b k t', c=2) + return codes[0], codes[1] + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): B, K, T = codes.shape - assert T % self.num_virtual_steps == 0 - codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps) - # We silently ignore potential errors from the LM when - # using extend_cardinality. - codes = codes % self.model.cardinality - return self.model.decode(codes, scale) + assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match" + assert K == self.num_codebooks, "Provided codes' number of codebooks does not match" + + scale_c0, scale_c1 = None, None + if scale is not None: + assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}" + scale_c0 = scale[0, ...] + scale_c1 = scale[1, ...] + + codes_c0, codes_c1 = self.get_left_right_codes(codes) + audio_c0 = self.model.decode(codes_c0, scale_c0) + audio_c1 = self.model.decode(codes_c1, scale_c1) + return torch.cat([audio_c0, audio_c1], dim=1) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + raise NotImplementedError("Not supported by interleaved stereo wrapped models.") diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py index c8aad8f06797eef3293605056e1de14d07c56c2a..c4ea2e5e800128c78226aed887fde46930adc817 100644 --- a/audiocraft/models/lm.py +++ b/audiocraft/models/lm.py @@ -41,7 +41,7 @@ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None method (str): Method name for init function. Valid options are: 'gaussian', 'uniform'. input_dim (int): Input dimension of the initialized module. - init_depth (Optional[int]): Optional init depth value used to rescale + init_depth (int, optional): Optional init depth value used to rescale the standard deviation if defined. """ # Compute std @@ -70,7 +70,7 @@ def init_layer(m: nn.Module, Args: m (nn.Module): Module to initialize. method (str): Method name for the init function. - init_depth (Optional[int]): Optional init depth value used to rescale + init_depth (int, optional): Optional init depth value used to rescale the standard deviation if defined. zero_bias_init (bool): Whether to initialize the bias to 0 or not. """ @@ -130,10 +130,10 @@ class LMModel(StreamingModule): hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. norm (str): Normalization method. norm_first (bool): Use pre-norm instead of post-norm. - emb_lr (Optional[float]): Embedding-specific learning rate. + emb_lr (float, optional): Embedding-specific learning rate. bias_proj (bool): Use bias for output projections. - weight_init (Optional[str]): Method for weight initialization. - depthwise_init (Optional[str]): Method for depthwise weight initialization. + weight_init (str, optional): Method for weight initialization. + depthwise_init (str, optional): Method for depthwise weight initialization. zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. cfg_dropout (float): Classifier-free guidance dropout. cfg_coef (float): Classifier-free guidance coefficient. @@ -179,11 +179,11 @@ class LMModel(StreamingModule): """Initialization of the transformer module weights. Args: - weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options. - depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid: + weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. + depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: 'current' where the depth corresponds to the current layer index or 'global' where the total number of layer is used as depth. If not set, no depthwise initialization strategy is used. - zero_bias_init (bool): Whether to initalize bias to zero or not. + zero_bias_init (bool): Whether to initialize bias to zero or not. """ assert depthwise_init is None or depthwise_init in ['current', 'global'] assert depthwise_init is None or weight_init is not None, \ @@ -225,17 +225,17 @@ class LMModel(StreamingModule): S the sequence steps, return the logits with shape [B, card, K, S]. Args: - indices (torch.Tensor): indices of the codes to model. - conditions (list[ConditioningAttributes]): conditionings to use when modeling + indices (torch.Tensor): Indices of the codes to model. + conditions (list of ConditioningAttributes): Conditions to use when modeling the given codes. Note that when evaluating multiple time with the same conditioning you should pre-compute those and pass them as `condition_tensors`. - condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning + condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning tensors, see `conditions`. Returns: torch.Tensor: Logits. """ B, K, S = sequence.shape - assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks' + assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks" input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) if condition_tensors is None: assert not self._is_streaming, "Conditions tensors should be precomputed when streaming." @@ -271,10 +271,10 @@ class LMModel(StreamingModule): Args: codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, K the number of codebooks and T the number of timesteps. - conditions (list[ConditioningAttributes]): conditionings to use when modeling + conditions (list of ConditioningAttributes): conditionings to use when modeling the given codes. Note that when evaluating multiple time with the same conditioning you should pre-compute those and pass them as `condition_tensors`. - condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning + condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning tensors, see `conditions`. Returns: LMOutput: Language model outputs @@ -314,7 +314,8 @@ class LMModel(StreamingModule): temp: float = 1.0, top_k: int = 0, top_p: float = 0.0, - cfg_coef: tp.Optional[float] = None) -> torch.Tensor: + cfg_coef: tp.Optional[float] = None, + two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor: """Sample next token from the model given a sequence and a set of conditions. The model supports multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). @@ -322,21 +323,22 @@ class LMModel(StreamingModule): sequence (torch.Tensor): Current sequence of shape [B, K, S] with K corresponding to the number of codebooks and S the number of sequence steps. S = 1 in streaming mode, except for the first step that contains a bigger prompt. - condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used, + condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used, should be twice the batch size, being the concatenation of the conditions + null conditions. use_sampling (bool): Whether to use a sampling strategy or not. temp (float): Sampling temperature. top_k (int): K for "top-k" sampling. top_p (float): P for "top-p" sampling. - cfg_coef (float): classifier free guidance coefficient + cfg_coef (float, optional): classifier free guidance coefficient Returns: next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. """ B = sequence.shape[0] cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef model = self if self._fsdp is None else self._fsdp - if self.two_step_cfg and cfg_conditions != {}: - assert isinstance(cfg_conditions, tuple) + two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg + if two_step_cfg and cfg_conditions != {}: + assert isinstance(cfg_conditions, tuple), type(cfg_conditions) condition_tensors, null_condition_tensors = cfg_conditions cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) state = self.get_streaming_state() @@ -388,7 +390,7 @@ class LMModel(StreamingModule): top_k: int = 250, top_p: float = 0.0, cfg_coef: tp.Optional[float] = None, - two_step_cfg: bool = False, + two_step_cfg: tp.Optional[bool] = None, remove_prompts: bool = False, check: bool = False, callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor: @@ -396,15 +398,19 @@ class LMModel(StreamingModule): be perform in a greedy fashion or using sampling with top K and top P strategies. Args: - prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T]. - conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None. - num_samples (int or None): Number of samples to generate when no prompt and no conditions are given. + prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. + conditions_tensors (list of ConditioningAttributes, optional): List of conditions. + num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. max_gen_len (int): Maximum generation length. use_sampling (bool): Whether to use a sampling strategy or not. temp (float): Sampling temperature. top_k (int): K for "top-k" sampling. top_p (float): P for "top-p" sampling. + cfg_coeff (float, optional): Classifier-free guidance coefficient. + two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation. remove_prompts (bool): Whether to remove prompts from generation or not. + check (bool): Whether to apply further checks on generated sequence. + callback (Callback, optional): Callback function to report generation progress. Returns: torch.Tensor: Generated tokens. """ @@ -412,7 +418,7 @@ class LMModel(StreamingModule): first_param = next(iter(self.parameters())) device = first_param.device - # Checking all input shapes are consistents. + # Checking all input shapes are consistent. possible_num_samples = [] if num_samples is not None: possible_num_samples.append(num_samples) @@ -422,7 +428,7 @@ class LMModel(StreamingModule): possible_num_samples.append(len(conditions)) else: possible_num_samples.append(1) - assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes" + assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" num_samples = possible_num_samples[0] # below we create set of conditions: one conditional and one unconditional @@ -432,7 +438,7 @@ class LMModel(StreamingModule): # 1. it is about x2 faster than doing 2 forward passes # 2. avoid the streaming API treating the 2 passes as part of different time steps # We also support doing two different passes, in particular to ensure that - # the padding structure is exactly the same between train anf test. + # the padding structure is exactly the same between train and test. # With a batch size of 1, this can be slower though. cfg_conditions: CFGConditions two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg @@ -489,7 +495,7 @@ class LMModel(StreamingModule): # sample next token from the model, next token shape is [B, K, 1] next_token = self._sample_next_token( curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, - cfg_coef=cfg_coef) + cfg_coef=cfg_coef, two_step_cfg=two_step_cfg) # ensure the tokens that should be masked are properly set to special_token_id # as the model never output special_token_id valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 97c662c3212b7695669cbfc5214ff2f099c3f319..f02ba115353a22c43926642e4dcc00376a4ada7e 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -24,18 +24,16 @@ from huggingface_hub import hf_hub_download import typing as tp import os -from omegaconf import OmegaConf +from omegaconf import OmegaConf, DictConfig import torch +import audiocraft from . import builders +from .encodec import CompressionModel -HF_MODEL_CHECKPOINTS_MAP = { - "small": "facebook/musicgen-small", - "medium": "facebook/musicgen-medium", - "large": "facebook/musicgen-large", - "melody": "facebook/musicgen-melody", -} +def get_audiocraft_cache_dir() -> tp.Optional[str]: + return os.environ.get('AUDIOCRAFT_CACHE_DIR', None) def _get_state_dict( @@ -44,6 +42,8 @@ def _get_state_dict( device='cpu', cache_dir: tp.Optional[str] = None, ): + if cache_dir is None: + cache_dir = get_audiocraft_cache_dir() # Return the state dict either from a file or url file_or_url_or_id = str(file_or_url_or_id) assert isinstance(file_or_url_or_id, str) @@ -58,19 +58,23 @@ def _get_state_dict( elif file_or_url_or_id.startswith('https://'): return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True) - elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP: + else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id] - file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir) + file = hf_hub_download( + repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, + library_name="audiocraft", library_version=audiocraft.__version__) return torch.load(file, map_location=device) - else: - raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.") + +def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) + pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + if 'pretrained' in pkg: + return CompressionModel.get_pretrained(pkg['pretrained'], device=device) cfg = OmegaConf.create(pkg['xp.cfg']) cfg.device = str(device) model = builders.get_compression_model(cfg) @@ -79,16 +83,67 @@ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', return model +def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) + + +def _delete_param(cfg: DictConfig, full_name: str): + parts = full_name.split('.') + for part in parts[:-1]: + if part in cfg: + cfg = cfg[part] + else: + return + OmegaConf.set_struct(cfg, False) + if parts[-1] in cfg: + del cfg[parts[-1]] + OmegaConf.set_struct(cfg, True) + + def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) + pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) cfg = OmegaConf.create(pkg['xp.cfg']) cfg.device = str(device) if cfg.device == 'cpu': cfg.dtype = 'float32' else: cfg.dtype = 'float16' + _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path') + _delete_param(cfg, 'conditioners.args.merge_text_conditions_p') + _delete_param(cfg, 'conditioners.args.drop_desc_p') model = builders.get_lm_model(cfg) model.load_state_dict(pkg['best_state']) model.eval() model.cfg = cfg return model + + +def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir) + + +def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], + device='cpu', + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None): + pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir) + models = [] + processors = [] + cfgs = [] + sample_rate = pkg['sample_rate'] + for i in range(pkg['n_bands']): + cfg = pkg[i]['cfg'] + model = builders.get_diffusion_model(cfg) + model_dict = pkg[i]['model_state'] + model.load_state_dict(model_dict) + model.to(device) + processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate) + processor_dict = pkg[i]['processor_state'] + processor.load_state_dict(processor_dict) + processor.to(device) + models.append(model) + processors.append(processor) + cfgs.append(cfg) + return models, processors, cfgs diff --git a/audiocraft/models/multibanddiffusion.py b/audiocraft/models/multibanddiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f312fa669e38e6159976486a23fb0f6ef47f58fa --- /dev/null +++ b/audiocraft/models/multibanddiffusion.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multi Band Diffusion models as described in +"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" +(paper link). +""" + +import typing as tp + +import torch +import julius + +from .unet import DiffusionUnet +from ..modules.diffusion_schedule import NoiseSchedule +from .encodec import CompressionModel +from ..solvers.compression import CompressionSolver +from .loaders import load_compression_model, load_diffusion_models + + +class DiffusionProcess: + """Sampling for a diffusion Model. + + Args: + model (DiffusionUnet): Diffusion U-Net model. + noise_schedule (NoiseSchedule): Noise schedule for diffusion process. + """ + def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None: + """ + """ + self.model = model + self.schedule = noise_schedule + + def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, + step_list: tp.Optional[tp.List[int]] = None): + """Perform one diffusion process to generate one of the bands. + + Args: + condition (tensor): The embeddings form the compression model. + initial_noise (tensor): The initial noise to start the process/ + """ + return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list, + condition=condition) + + +class MultiBandDiffusion: + """Sample from multiple diffusion models. + + Args: + DPs (list of DiffusionProcess): Diffusion processes. + codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens. + """ + def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None: + self.DPs = DPs + self.codec_model = codec_model + self.device = next(self.codec_model.parameters()).device + + @property + def sample_rate(self) -> int: + return self.codec_model.sample_rate + + @staticmethod + def get_mbd_musicgen(device=None): + """Load our diffusion models trained for MusicGen.""" + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + path = 'facebook/multiband-diffusion' + filename = 'mbd_musicgen_32khz.th' + name = 'facebook/musicgen-small' + codec_model = load_compression_model(name, device=device) + models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device) + DPs = [] + for i in range(len(models)): + schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) + DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) + return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) + + @staticmethod + def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True, + device: tp.Optional[tp.Union[torch.device, str]] = None, + n_q: tp.Optional[int] = None): + """Get the pretrained Models for MultibandDiffusion. + + Args: + bw (float): Bandwidth of the compression model. + pretrained (bool): Whether to use / download if necessary the models. + device (torch.device or str, optional): Device on which the models are loaded. + n_q (int, optional): Number of quantizers to use within the compression model. + """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available" + if n_q is not None: + assert n_q in [2, 4, 8] + assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \ + f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}" + n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw] + codec_model = CompressionSolver.model_from_checkpoint( + '//pretrained/facebook/encodec_24khz', device=device) + codec_model.set_num_codebooks(n_q) + codec_model = codec_model.to(device) + path = 'facebook/multiband-diffusion' + filename = f'mbd_comp_{n_q}.pt' + models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device) + DPs = [] + for i in range(len(models)): + schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) + DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) + return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) + + return MultiBandDiffusion(DPs, codec_model) + + @torch.no_grad() + def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform. + Args: + wav (torch.Tensor): The audio that we want to extract the conditioning from + sample_rate (int): sample rate of the audio""" + if sample_rate != self.sample_rate: + wav = julius.resample_frac(wav, sample_rate, self.sample_rate) + codes, scale = self.codec_model.encode(wav) + assert scale is None, "Scaled compression models not supported." + emb = self.get_emb(codes) + return emb + + @torch.no_grad() + def get_emb(self, codes: torch.Tensor): + """Get latent representation from the discrete codes + Argrs: + codes (torch.Tensor): discrete tokens""" + emb = self.codec_model.decode_latent(codes) + return emb + + def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None, + step_list: tp.Optional[tp.List[int]] = None): + """Generate Wavform audio from the latent embeddings of the compression model + Args: + emb (torch.Tensor): Conditioning embeddinds + size (none torch.Size): size of the output + if None this is computed from the typical upsampling of the model + step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step. + """ + if size is None: + upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate) + size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling]) + assert size[0] == emb.size(0) + out = torch.zeros(size).to(self.device) + for DP in self.DPs: + out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out)) + return out + + def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1): + """match the eq to the encodec output by matching the standard deviation of some frequency bands + Args: + wav (torch.Tensor): audio to equalize + ref (torch.Tensor):refenrence audio from which we match the spectrogram. + n_bands (int): number of bands of the eq + strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching. + """ + split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device) + bands = split(wav) + bands_ref = split(ref) + out = torch.zeros_like(ref) + for i in range(n_bands): + out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness + return out + + def regenerate(self, wav: torch.Tensor, sample_rate: int): + """Regenerate a wavform through compression and diffusion regeneration. + Args: + wav (torch.Tensor): Original 'ground truth' audio + sample_rate (int): sample rate of the input (and output) wav + """ + if sample_rate != self.codec_model.sample_rate: + wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate) + emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate) + size = wav.size() + out = self.generate(emb, size=size) + if sample_rate != self.codec_model.sample_rate: + out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate) + return out + + def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): + """Generate Waveform audio with diffusion from the discrete codes. + Args: + tokens (torch.Tensor): discrete codes + n_bands (int): bands for the eq matching. + """ + wav_encodec = self.codec_model.decode(tokens) + condition = self.get_emb(tokens) + wav_diffusion = self.generate(emb=condition, size=wav_encodec.size()) + return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands) diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py index 007dd9e0ed1cfd359fb4889e7f4108248e189941..88ee13b6a5da2a54e580db7c39accb1acbade6b4 100644 --- a/audiocraft/models/musicgen.py +++ b/audiocraft/models/musicgen.py @@ -9,15 +9,16 @@ Main model for using MusicGen. This will combine all the required components and provide easy access to the generation API. """ -import os import typing as tp +import warnings +import omegaconf import torch from .encodec import CompressionModel from .lm import LMModel -from .builders import get_debug_compression_model, get_debug_lm_model -from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP +from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model +from .loaders import load_compression_model, load_lm_model from ..data.audio_utils import convert_audio from ..modules.conditioners import ConditioningAttributes, WavCondition from ..utils.autocast import TorchAutocast @@ -27,6 +28,15 @@ MelodyList = tp.List[tp.Optional[torch.Tensor]] MelodyType = tp.Union[torch.Tensor, MelodyList] +# backward compatible names mapping +_HF_MODEL_CHECKPOINTS_MAP = { + "small": "facebook/musicgen-small", + "medium": "facebook/musicgen-medium", + "large": "facebook/musicgen-large", + "melody": "facebook/musicgen-melody", +} + + class MusicGen: """MusicGen main model with convenient generation API. @@ -35,14 +45,36 @@ class MusicGen: compression_model (CompressionModel): Compression model used to map audio to invertible discrete representations. lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. """ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, - max_duration: float = 30): + max_duration: tp.Optional[float] = None): self.name = name self.compression_model = compression_model self.lm = lm - self.max_duration = max_duration + self.cfg: tp.Optional[omegaconf.DictConfig] = None + # Just to be safe, let's put everything in eval mode. + self.compression_model.eval() + self.lm.eval() + + if hasattr(lm, 'cfg'): + cfg = lm.cfg + assert isinstance(cfg, omegaconf.DictConfig) + self.cfg = cfg + + if self.cfg is not None: + self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg) + + if max_duration is None: + if self.cfg is not None: + max_duration = lm.cfg.dataset.segment_duration # type: ignore + else: + raise ValueError("You must provide max_duration when building directly MusicGen") + assert max_duration is not None + self.max_duration: float = max_duration self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} self.set_generation_params(duration=15) # 15 seconds by default self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None @@ -53,7 +85,7 @@ class MusicGen: enabled=True, device_type=self.device.type, dtype=torch.float16) @property - def frame_rate(self) -> int: + def frame_rate(self) -> float: """Roughly the number of AR steps per seconds.""" return self.compression_model.frame_rate @@ -68,14 +100,17 @@ class MusicGen: return self.compression_model.channels @staticmethod - def get_pretrained(name: str = 'melody', device=None): + def get_pretrained(name: str = 'facebook/musicgen-melody', device=None): """Return pretrained model, we provide four models: - - small (300M), text to music, # see: https://huggingface.co./facebook/musicgen-small - - medium (1.5B), text to music, # see: https://huggingface.co./facebook/musicgen-medium - - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co./facebook/musicgen-melody - - large (3.3B), text to music, # see: https://huggingface.co./facebook/musicgen-large + - facebook/musicgen-small (300M), text to music, + # see: https://huggingface.co./facebook/musicgen-small + - facebook/musicgen-medium (1.5B), text to music, + # see: https://huggingface.co./facebook/musicgen-medium + - facebook/musicgen-melody (1.5B) text to music and text+melody to music, + # see: https://huggingface.co./facebook/musicgen-melody + - facebook/musicgen-large (3.3B), text to music, + # see: https://huggingface.co./facebook/musicgen-large """ - if device is None: if torch.cuda.device_count(): device = 'cuda' @@ -86,20 +121,19 @@ class MusicGen: # used only for unit tests compression_model = get_debug_compression_model(device) lm = get_debug_lm_model(device) - return MusicGen(name, compression_model, lm) - - if name not in HF_MODEL_CHECKPOINTS_MAP: - if not os.path.isfile(name) and not os.path.isdir(name): - raise ValueError( - f"{name} is not a valid checkpoint name. " - f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}" - ) - - cache_dir = os.environ.get('MUSICGEN_ROOT', None) - compression_model = load_compression_model(name, device=device, cache_dir=cache_dir) - lm = load_lm_model(name, device=device, cache_dir=cache_dir) - if name == 'melody': + return MusicGen(name, compression_model, lm, max_duration=30) + + if name in _HF_MODEL_CHECKPOINTS_MAP: + warnings.warn( + "MusicGen pretrained model relying on deprecated checkpoint mapping. " + + f"Please use full pre-trained id instead: facebook/musicgen-{name}") + name = _HF_MODEL_CHECKPOINTS_MAP[name] + + lm = load_lm_model(name, device=device) + compression_model = load_compression_model(name, device=device) + if 'self_wav' in lm.condition_provider.conditioners: lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True + lm.condition_provider.conditioners['self_wav']._use_masking = False return MusicGen(name, compression_model, lm) @@ -139,7 +173,9 @@ class MusicGen: """Override the default progress callback.""" self._progress_callback = progress_callback - def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor: + def generate_unconditional(self, num_samples: int, progress: bool = False, + return_tokens: bool = False) -> tp.Union[torch.Tensor, + tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples in an unconditional manner. Args: @@ -148,25 +184,34 @@ class MusicGen: """ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) - def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor: + def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text. Args: - descriptions (tp.List[str]): A list of strings used as text conditioning. + descriptions (list of str): A list of strings used as text conditioning. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. """ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType, - melody_sample_rate: int, progress: bool = False) -> torch.Tensor: + melody_sample_rate: int, progress: bool = False, + return_tokens: bool = False) -> tp.Union[torch.Tensor, + tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on text and melody. Args: - descriptions (tp.List[str]): A list of strings used as text conditioning. + descriptions (list of str): A list of strings used as text conditioning. melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as melody conditioning. Should have shape [B, C, T] with B matching the description length, C=1 or 2. It can be [C, T] if there is a single description. It can also be @@ -192,18 +237,22 @@ class MusicGen: attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, melody_wavs=melody_wavs) assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, - progress: bool = False) -> torch.Tensor: + progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: """Generate samples conditioned on audio prompts. Args: prompt (torch.Tensor): A batch of waveforms used for continuation. Prompt should be [B, C, T], or [C, T] if only one sample is generated. prompt_sample_rate (int): Sampling rate of the given audio waveforms. - descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None. + descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. """ if prompt.dim() == 2: @@ -215,7 +264,10 @@ class MusicGen: descriptions = [None] * len(prompt) attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) assert prompt_tokens is not None - return self._generate_tokens(attributes, prompt_tokens, progress) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) @torch.no_grad() def _prepare_tokens_and_attributes( @@ -227,9 +279,9 @@ class MusicGen: """Prepare model inputs. Args: - descriptions (tp.List[str]): A list of strings used as text conditioning. + descriptions (list of str): A list of strings used as text conditioning. prompt (torch.Tensor): A batch of waveforms used for continuation. - melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms + melody_wavs (torch.Tensor, optional): A batch of waveforms used as melody conditioning. Defaults to None. """ attributes = [ @@ -239,11 +291,12 @@ class MusicGen: if melody_wavs is None: for attr in attributes: attr.wav['self_wav'] = WavCondition( - torch.zeros((1, 1), device=self.device), + torch.zeros((1, 1, 1), device=self.device), torch.tensor([0], device=self.device), - path='null_wav') # type: ignore + sample_rate=[self.sample_rate], + path=[None]) else: - if self.name != "melody": + if 'self_wav' not in self.lm.condition_provider.conditioners: raise RuntimeError("This model doesn't support melody conditioning. " "Use the `melody` model.") assert len(melody_wavs) == len(descriptions), \ @@ -252,13 +305,17 @@ class MusicGen: for attr, melody in zip(attributes, melody_wavs): if melody is None: attr.wav['self_wav'] = WavCondition( - torch.zeros((1, 1), device=self.device), + torch.zeros((1, 1, 1), device=self.device), torch.tensor([0], device=self.device), - path='null_wav') # type: ignore + sample_rate=[self.sample_rate], + path=[None]) else: attr.wav['self_wav'] = WavCondition( - melody.to(device=self.device), - torch.tensor([melody.shape[-1]], device=self.device)) + melody[None].to(device=self.device), + torch.tensor([melody.shape[-1]], device=self.device), + sample_rate=[self.sample_rate], + path=[None], + ) if prompt is not None: if descriptions is not None: @@ -275,8 +332,8 @@ class MusicGen: """Generate discrete audio tokens given audio prompt and/or conditions. Args: - attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody). - prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation. + attributes (list of ConditioningAttributes): Conditions used for generation (text/melody). + prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. Returns: torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. @@ -335,12 +392,13 @@ class MusicGen: # we wouldn't have the full wav. initial_position = int(time_offset * self.sample_rate) wav_target_length = int(self.max_duration * self.sample_rate) - print(initial_position / self.sample_rate, wav_target_length / self.sample_rate) positions = torch.arange(initial_position, initial_position + wav_target_length, device=self.device) attr.wav['self_wav'] = WavCondition( - ref_wav[0][:, positions % wav_length], - torch.full_like(ref_wav[1], wav_target_length)) + ref_wav[0][..., positions % wav_length], + torch.full_like(ref_wav[1], wav_target_length), + [self.sample_rate] * ref_wav[0].size(0), + [None], [0.]) with self.autocast: gen_tokens = self.lm.generate( prompt_tokens, attributes, @@ -354,8 +412,10 @@ class MusicGen: current_gen_offset += stride_tokens gen_tokens = torch.cat(all_tokens, dim=-1) + return gen_tokens - # generate audio + def generate_audio(self, gen_tokens: torch.Tensor): + """Generate Audio from tokens""" assert gen_tokens.dim() == 3 with torch.no_grad(): gen_audio = self.compression_model.decode(gen_tokens, None) diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..db4a6df8e309c21fede37abdbe3c862932027641 --- /dev/null +++ b/audiocraft/models/unet.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pytorch Unet Module used for diffusion. +""" + +from dataclasses import dataclass +import typing as tp + +import torch +from torch import nn +from torch.nn import functional as F +from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding + + +@dataclass +class Output: + sample: torch.Tensor + + +def get_model(cfg, channels: int, side: int, num_steps: int): + if cfg.model == 'unet': + return DiffusionUnet( + chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + else: + raise RuntimeError('Not Implemented') + + +class ResBlock(nn.Module): + def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, + dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + stride = 1 + padding = dilation * (kernel - stride) // 2 + Conv = nn.Conv1d + Drop = nn.Dropout1d + self.norm1 = nn.GroupNorm(norm_groups, channels) + self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.activation1 = activation() + self.dropout1 = Drop(dropout) + + self.norm2 = nn.GroupNorm(norm_groups, channels) + self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.activation2 = activation() + self.dropout2 = Drop(dropout) + + def forward(self, x): + h = self.dropout1(self.conv1(self.activation1(self.norm1(x)))) + h = self.dropout2(self.conv2(self.activation2(self.norm2(h)))) + return x + h + + +class DecoderLayer(nn.Module): + def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, + norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + padding = (kernel - stride) // 2 + self.res_blocks = nn.Sequential( + *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) + for idx in range(res_blocks)]) + self.norm = nn.GroupNorm(norm_groups, chin) + ConvTr = nn.ConvTranspose1d + self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False) + self.activation = activation() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.res_blocks(x) + x = self.norm(x) + x = self.activation(x) + x = self.convtr(x) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, + norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + padding = (kernel - stride) // 2 + Conv = nn.Conv1d + self.conv = Conv(chin, chout, kernel, stride, padding, bias=False) + self.norm = nn.GroupNorm(norm_groups, chout) + self.activation = activation() + self.res_blocks = nn.Sequential( + *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) + for idx in range(res_blocks)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, T = x.shape + stride, = self.conv.stride + pad = (stride - (T % stride)) % stride + x = F.pad(x, (0, pad)) + + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + x = self.res_blocks(x) + return x + + +class BLSTM(nn.Module): + """BiLSTM with same hidden units as input dim. + """ + def __init__(self, dim, layers=2): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +class DiffusionUnet(nn.Module): + def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2., + max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, + bilstm: bool = False, transformer: bool = False, + codec_dim: tp.Optional[int] = None, **kwargs): + super().__init__() + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.embeddings: tp.Optional[nn.ModuleList] = None + self.embedding = nn.Embedding(num_steps, hidden) + if emb_all_layers: + self.embeddings = nn.ModuleList() + self.condition_embedding: tp.Optional[nn.Module] = None + for d in range(depth): + encoder = EncoderLayer(chin, hidden, **kwargs) + decoder = DecoderLayer(hidden, chin, **kwargs) + self.encoders.append(encoder) + self.decoders.insert(0, decoder) + if emb_all_layers and d > 0: + assert self.embeddings is not None + self.embeddings.append(nn.Embedding(num_steps, hidden)) + chin = hidden + hidden = min(int(chin * growth), max_channels) + self.bilstm: tp.Optional[nn.Module] + if bilstm: + self.bilstm = BLSTM(chin) + else: + self.bilstm = None + self.use_transformer = transformer + self.cross_attention = False + if transformer: + self.cross_attention = cross_attention + self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False, + cross_attention=cross_attention) + + self.use_codec = False + if codec_dim is not None: + self.conv_codec = nn.Conv1d(codec_dim, chin, 1) + self.use_codec = True + + def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None): + skips = [] + bs = x.size(0) + z = x + view_args = [1] + if type(step) is torch.Tensor: + step_tensor = step + else: + step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs) + + for idx, encoder in enumerate(self.encoders): + z = encoder(z) + if idx == 0: + z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z) + elif self.embeddings is not None: + z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z) + + skips.append(z) + + if self.use_codec: # insert condition in the bottleneck + assert condition is not None, "Model defined for conditionnal generation" + condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim + assert condition_emb.size(-1) <= 2 * z.size(-1), \ + f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}" + if not self.cross_attention: + + condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1)) + assert z.size() == condition_emb.size() + z += condition_emb + cross_attention_src = None + else: + cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C + B, T, C = cross_attention_src.shape + positions = torch.arange(T, device=x.device).view(1, -1, 1) + pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype) + cross_attention_src = cross_attention_src + pos_emb + if self.use_transformer: + z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1) + else: + if self.bilstm is None: + z = torch.zeros_like(z) + else: + z = self.bilstm(z) + + for decoder in self.decoders: + s = skips.pop(-1) + z = z[:, :, :s.shape[2]] + z = z + s + z = decoder(z) + + z = z[:, :, :x.shape[2]] + return Output(z) diff --git a/audiocraft/modules/__init__.py b/audiocraft/modules/__init__.py index 81ba30f6466ff91b90490a4fb92f7d3d0d00144d..61418616ef18f0ecca56a007c43af4a731d98b9b 100644 --- a/audiocraft/modules/__init__.py +++ b/audiocraft/modules/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Modules used for building the models.""" # flake8: noqa from .conv import ( @@ -18,3 +19,4 @@ from .conv import ( ) from .lstm import StreamableLSTM from .seanet import SEANetEncoder, SEANetDecoder +from .transformer import StreamingTransformer \ No newline at end of file diff --git a/audiocraft/modules/activations.py b/audiocraft/modules/activations.py index 8bd6f2917a56d72db56555d0ff54b2311bc21778..2d83d7c4c2dc84c64b724eadbe06157507d4f20d 100644 --- a/audiocraft/modules/activations.py +++ b/audiocraft/modules/activations.py @@ -84,7 +84,7 @@ def get_activation_fn( If the supplied activation is not a string that is recognized, the activation is passed back. Args: - activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check + activation (str, or Callable[[Tensor], Tensor]): Activation to check """ if isinstance(activation, str): if activation == "reglu": diff --git a/audiocraft/modules/chroma.py b/audiocraft/modules/chroma.py new file mode 100644 index 0000000000000000000000000000000000000000..e84fb66b4a4aaefb0b3ccac8a9a44c3b20e48f61 --- /dev/null +++ b/audiocraft/modules/chroma.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import typing as tp + +from einops import rearrange +from librosa import filters +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + + +class ChromaExtractor(nn.Module): + """Chroma extraction and quantization. + + Args: + sample_rate (int): Sample rate for the chroma extraction. + n_chroma (int): Number of chroma bins for the chroma extraction. + radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). + nfft (int, optional): Number of FFT. + winlen (int, optional): Window length. + winhop (int, optional): Window hop size. + argmax (bool, optional): Whether to use argmax. Defaults to False. + norm (float, optional): Norm for chroma normalization. Defaults to inf. + """ + def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None, + winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False, + norm: float = torch.inf): + super().__init__() + self.winlen = winlen or 2 ** radix2_exp + self.nfft = nfft or self.winlen + self.winhop = winhop or (self.winlen // 4) + self.sample_rate = sample_rate + self.n_chroma = n_chroma + self.norm = norm + self.argmax = argmax + self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, + n_chroma=self.n_chroma)), persistent=False) + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, + hop_length=self.winhop, power=2, center=True, + pad=0, normalized=True) + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + T = wav.shape[-1] + # in case we are getting a wav that was dropped out (nullified) + # from the conditioner, make sure wav length is no less that nfft + if T < self.nfft: + pad = self.nfft - T + r = 0 if pad % 2 == 0 else 1 + wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) + assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" + + spec = self.spec(wav).squeeze(1) + raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) + norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') + + if self.argmax: + idx = norm_chroma.argmax(-1, keepdim=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return norm_chroma diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py index c5b35cbea8cff84aa56116dbdd860fc72a913a13..61362588403a3eef4a4b1b4ad4595526722da20f 100644 --- a/audiocraft/modules/codebooks_patterns.py +++ b/audiocraft/modules/codebooks_patterns.py @@ -122,7 +122,7 @@ class Pattern: Args: timesteps (int): Maximum number of timesteps steps to consider. keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. - device (Union[torch.device, str]): Device for created tensors. + device (torch.device or str): Device for created tensors. Returns: indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. @@ -189,9 +189,9 @@ class Pattern: keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. Steps that are beyond valid steps will be replaced by the special_token in that case. is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. - device (Union[torch.device, str]): Device for created tensors. + device (torch.device or str): Device for created tensors. Returns: - torch.Tensor: Indexes for reconstructing the output, of shape [K, T]. + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. """ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout @@ -295,7 +295,7 @@ class CodebooksPatternProvider(ABC): """Builds pattern with specific interleaving between codebooks. Args: - timesteps (int): Total numer of timesteps. + timesteps (int): Total number of timesteps. """ raise NotImplementedError() @@ -318,7 +318,7 @@ class DelayedPatternProvider(CodebooksPatternProvider): Args: n_q (int): Number of codebooks. - delays (Optional[List[int]]): Delay for each of the codebooks. + delays (list of int, optional): Delay for each of the codebooks. If delays not defined, each codebook is delayed by 1 compared to the previous one. flatten_first (int): Flatten the first N timesteps. empty_initial (int): Prepend with N empty list of coordinates. @@ -406,10 +406,10 @@ class UnrolledPatternProvider(CodebooksPatternProvider): Args: n_q (int): Number of codebooks. - flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined, + flattening (list of int, optional): Flattening schema over the codebooks. If not defined, the codebooks will be flattened to 1 codebook per step, meaning that the sequence will have n_q extra steps for each timestep. - delays (Optional[List[int]]): Delay for each of the codebooks. If not defined, + delays (list of int, optional): Delay for each of the codebooks. If not defined, no delay is added and therefore will default to [0] * ``n_q``. Note that two codebooks that will be flattened to the same inner step should have the same delay, otherwise the pattern is considered as invalid. @@ -462,7 +462,7 @@ class UnrolledPatternProvider(CodebooksPatternProvider): """Builds pattern for delay across codebooks. Args: - timesteps (int): Total numer of timesteps. + timesteps (int): Total number of timesteps. """ # the PatternLayout is built as a tuple of sequence position and list of coordinates # so that it can be reordered properly given the required delay between codebooks of given timesteps @@ -486,13 +486,18 @@ class UnrolledPatternProvider(CodebooksPatternProvider): return Pattern(out, n_q=self.n_q, timesteps=timesteps) -class VALLEPattern(CodebooksPatternProvider): - """Almost VALL-E style pattern. We futher allow some delays for the - codebooks other than the first one. +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. + + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. Args: n_q (int): Number of codebooks. - delays (Optional[List[int]]): Delay for each of the codebooks. + delays (list of int, optional): Delay for each of the codebooks. If delays not defined, each codebook is delayed by 1 compared to the previous one. """ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py index 82792316024b88d4c5c38b0a28f443627771d509..178957d1771dc4c6f2df028fd9bb60f204567955 100644 --- a/audiocraft/modules/conditioners.py +++ b/audiocraft/modules/conditioners.py @@ -10,87 +10,61 @@ from dataclasses import dataclass, field from itertools import chain import logging import math +from pathlib import Path import random import re import typing as tp import warnings -from einops import rearrange +import einops from num2words import num2words import spacy -from transformers import T5EncoderModel, T5Tokenizer # type: ignore -import torchaudio +from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore import torch from torch import nn -from torch import Tensor import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence +from .chroma import ChromaExtractor from .streaming import StreamingModule from .transformer import create_sin_embedding +from ..data.audio import audio_read from ..data.audio_dataset import SegmentInfo +from ..data.audio_utils import convert_audio +from ..environment import AudioCraftEnvironment +from ..quantization import ResidualVectorQuantizer from ..utils.autocast import TorchAutocast -from ..utils.utils import hash_trick, length_to_mask, collate +from ..utils.cache import EmbeddingCache +from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once logger = logging.getLogger(__name__) TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist) -ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask +ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask class WavCondition(tp.NamedTuple): - wav: Tensor - length: Tensor + wav: torch.Tensor + length: torch.Tensor + sample_rate: tp.List[int] path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] -def nullify_condition(condition: ConditionType, dim: int = 1): - """This function transforms an input condition to a null condition. - The way it is done by converting it to a single zero vector similarly - to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. - - Args: - condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor]) - dim (int): the dimension that will be truncated (should be the time dimension) - WARNING!: dim should not be the batch dimension! - Returns: - ConditionType: a tuple of null condition and mask - """ - assert dim != 0, "dim cannot be the batch dimension!" - assert type(condition) == tuple and \ - type(condition[0]) == Tensor and \ - type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!" - cond, mask = condition - B = cond.shape[0] - last_dim = cond.dim() - 1 - out = cond.transpose(dim, last_dim) - out = 0. * out[..., :1] - out = out.transpose(dim, last_dim) - mask = torch.zeros((B, 1), device=out.device).int() - assert cond.dim() == out.dim() - return out, mask - - -def nullify_wav(wav: Tensor) -> WavCondition: - """Create a nullified WavCondition from a wav tensor with appropriate shape. - - Args: - wav (Tensor): tensor of shape [B, T] - Returns: - WavCondition: wav condition with nullified wav. - """ - null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1) - return WavCondition( - wav=null_wav, - length=torch.tensor([0] * wav.shape[0], device=wav.device), - path=['null_wav'] * wav.shape[0] - ) +class JointEmbedCondition(tp.NamedTuple): + wav: torch.Tensor + text: tp.List[tp.Optional[str]] + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] @dataclass class ConditioningAttributes: text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) wav: tp.Dict[str, WavCondition] = field(default_factory=dict) + joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) def __getitem__(self, item): return getattr(self, item) @@ -103,14 +77,23 @@ class ConditioningAttributes: def wav_attributes(self): return self.wav.keys() + @property + def joint_embed_attributes(self): + return self.joint_embed.keys() + @property def attributes(self): - return {"text": self.text_attributes, "wav": self.wav_attributes} + return { + "text": self.text_attributes, + "wav": self.wav_attributes, + "joint_embed": self.joint_embed_attributes, + } def to_flat_dict(self): return { **{f"text.{k}": v for k, v in self.text.items()}, **{f"wav.{k}": v for k, v in self.wav.items()}, + **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} } @classmethod @@ -131,11 +114,74 @@ class SegmentWithAttributes(SegmentInfo): raise NotImplementedError() +def nullify_condition(condition: ConditionType, dim: int = 1): + """Transform an input condition to a null condition. + The way it is done by converting it to a single zero vector similarly + to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. + + Args: + condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]) + dim (int): The dimension that will be truncated (should be the time dimension) + WARNING!: dim should not be the batch dimension! + Returns: + ConditionType: A tuple of null condition and mask + """ + assert dim != 0, "dim cannot be the batch dimension!" + assert isinstance(condition, tuple) and \ + isinstance(condition[0], torch.Tensor) and \ + isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!" + cond, mask = condition + B = cond.shape[0] + last_dim = cond.dim() - 1 + out = cond.transpose(dim, last_dim) + out = 0. * out[..., :1] + out = out.transpose(dim, last_dim) + mask = torch.zeros((B, 1), device=out.device).int() + assert cond.dim() == out.dim() + return out, mask + + +def nullify_wav(cond: WavCondition) -> WavCondition: + """Transform a WavCondition to a nullified WavCondition. + It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. + + Args: + cond (WavCondition): Wav condition with wav, tensor of shape [B, T]. + Returns: + WavCondition: Nullified wav condition. + """ + null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) + return WavCondition( + wav=null_wav, + length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device), + sample_rate=cond.sample_rate, + path=[None] * cond.wav.shape[0], + seek_time=[None] * cond.wav.shape[0], + ) + + +def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: + """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, + and replacing metadata by dummy attributes. + + Args: + cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T]. + """ + null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1) + return JointEmbedCondition( + wav=null_wav, text=[None] * len(embed.text), + length=torch.LongTensor([0]).to(embed.wav.device), + sample_rate=embed.sample_rate, + path=[None] * embed.wav.shape[0], + seek_time=[0] * embed.wav.shape[0], + ) + + class Tokenizer: - """Base class for all tokenizers + """Base tokenizer implementation (in case we want to introduce more advances tokenizers in the future). """ - def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]: + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError() @@ -146,7 +192,7 @@ class WhiteSpaceTokenizer(Tokenizer): [[78, 62, 31, 4, 78, 25, 19, 34], [59, 77, 0, 0, 0, 0, 0, 0]] """ - PUNCTUATIONS = "?:!.,;" + PUNCTUATION = "?:!.,;" def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", lemma: bool = True, stopwords: bool = True) -> None: @@ -161,18 +207,15 @@ class WhiteSpaceTokenizer(Tokenizer): self.nlp = spacy.load(language) @tp.no_type_check - def __call__( - self, - texts: tp.List[tp.Optional[str]], - return_text: bool = False - ) -> tp.Tuple[Tensor, Tensor]: + def __call__(self, texts: tp.List[tp.Optional[str]], + return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]: """Take a list of strings and convert them to a tensor of indices. Args: - texts (tp.List[str]): List of strings. + texts (list[str]): List of strings. return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. Returns: - tp.Tuple[Tensor, Tensor]: + tuple[torch.Tensor, torch.Tensor]: - Indices of words in the LUT. - And a mask indicating where the padding tokens are """ @@ -181,7 +224,7 @@ class WhiteSpaceTokenizer(Tokenizer): for i, text in enumerate(texts): # if current sample doesn't have a certain attribute, replace with pad token if text is None: - output.append(Tensor([self.pad_idx])) + output.append(torch.Tensor([self.pad_idx])) lengths.append(0) continue @@ -192,15 +235,15 @@ class WhiteSpaceTokenizer(Tokenizer): # remove stopwords if self.stopwords: text = [w for w in text if not w.is_stop] # type: ignore - # remove punctuations - text = [w for w in text if w.text not in self.PUNCTUATIONS] # type: ignore + # remove punctuation + text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore # lemmatize if needed text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore texts[i] = " ".join(text) lengths.append(len(text)) # convert to tensor - tokens = Tensor([hash_trick(w, self.n_bins) for w in text]) + tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text]) output.append(tokens) mask = length_to_mask(torch.IntTensor(lengths)).int() @@ -224,7 +267,7 @@ class NoopTokenizer(Tokenizer): self.n_bins = n_bins self.pad_idx = pad_idx - def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]: + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: output, lengths = [], [] for text in texts: # if current sample doesn't have a certain attribute, replace with pad token @@ -241,15 +284,16 @@ class NoopTokenizer(Tokenizer): class BaseConditioner(nn.Module): - """Base model for all conditioner modules. We allow the output dim to be different - than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large; + """Base model for all conditioner modules. + We allow the output dim to be different than the hidden dim for two reasons: + 1) keep our LUTs small when the vocab is large; 2) make all condition dims consistent. Args: - dim (int): Hidden dim of the model (text-encoder/LUT). + dim (int): Hidden dim of the model. output_dim (int): Output dim of the conditioner. """ - def __init__(self, dim, output_dim): + def __init__(self, dim: int, output_dim: int): super().__init__() self.dim = dim self.output_dim = output_dim @@ -294,9 +338,9 @@ class LUTConditioner(TextConditioner): super().__init__(dim, output_dim) self.embed = nn.Embedding(n_bins, dim) self.tokenizer: Tokenizer - if tokenizer == "whitespace": + if tokenizer == 'whitespace': self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) - elif tokenizer == "noop": + elif tokenizer == 'noop': self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) else: raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") @@ -346,13 +390,12 @@ class T5Conditioner(TextConditioner): def __init__(self, name: str, output_dim: int, finetune: bool, device: str, autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., normalize_text: bool = False): - assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})" + assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})" super().__init__(self.MODELS_DIMS[name], output_dim) self.device = device self.name = name self.finetune = finetune self.word_dropout = word_dropout - if autocast_dtype is None or self.device == 'cpu': self.autocast = TorchAutocast(enabled=False) if self.device != 'cpu': @@ -378,7 +421,7 @@ class T5Conditioner(TextConditioner): else: # this makes sure that the t5 models is not part # of the saved checkpoint - self.__dict__["t5"] = t5.to(device) + self.__dict__['t5'] = t5.to(device) self.normalize_text = normalize_text if normalize_text: @@ -398,13 +441,13 @@ class T5Conditioner(TextConditioner): empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) - inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device) - mask = inputs["attention_mask"] + inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) + mask = inputs['attention_mask'] mask[empty_idx, :] = 0 # zero-out index where the input is non-existant return inputs def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: - mask = inputs["attention_mask"] + mask = inputs['attention_mask'] with torch.set_grad_enabled(self.finetune), self.autocast: embeds = self.t5(**inputs).last_hidden_state embeds = self.output_proj(embeds.to(self.output_proj.weight)) @@ -426,204 +469,558 @@ class WaveformConditioner(BaseConditioner): def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): super().__init__(dim, output_dim) self.device = device + # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample. + self._use_masking = True - def tokenize(self, wav_length: WavCondition) -> WavCondition: - wav, length, path = wav_length + def tokenize(self, x: WavCondition) -> WavCondition: + wav, length, sample_rate, path, seek_time = x assert length is not None - return WavCondition(wav.to(self.device), length.to(self.device), path) + return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) - def _get_wav_embedding(self, wav: Tensor) -> Tensor: - """Gets as input a wav and returns a dense vector of conditions.""" + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Gets as input a WavCondition and returns a dense embedding.""" raise NotImplementedError() def _downsampling_factor(self): """Returns the downsampling factor of the embedding model.""" raise NotImplementedError() - def forward(self, inputs: WavCondition) -> ConditionType: - """ + def forward(self, x: WavCondition) -> ConditionType: + """Extract condition embedding and mask from a waveform and its metadata. Args: - input (WavCondition): Tuple of (waveform, lengths). + x (WavCondition): Waveform condition containing raw waveform and metadata. Returns: - ConditionType: Dense vector representing the conditioning along with its' mask. + ConditionType: a dense vector representing the conditioning along with its mask """ - wav, lengths, path = inputs + wav, lengths, *_ = x with torch.no_grad(): - embeds = self._get_wav_embedding(wav) + embeds = self._get_wav_embedding(x) embeds = embeds.to(self.output_proj.weight) embeds = self.output_proj(embeds) - if lengths is not None: + if lengths is not None and self._use_masking: lengths = lengths / self._downsampling_factor() mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore else: - mask = torch.ones_like(embeds) - embeds = (embeds * mask.unsqueeze(2).to(self.device)) - + mask = torch.ones_like(embeds[..., 0]) + embeds = (embeds * mask.unsqueeze(-1)) return embeds, mask class ChromaStemConditioner(WaveformConditioner): - """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by - the insight the drums and bass often dominate the chroma, leading to the chroma not containing the - information about melody. + """Chroma conditioner based on stems. + The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as + the drums and bass often dominate the chroma leading to the chroma features + not containing information about the melody. Args: output_dim (int): Output dimension for the conditioner. sample_rate (int): Sample rate for the chroma extractor. - n_chroma (int): Number of chroma for the chroma extractor. - radix2_exp (int): Radix2 exponent for the chroma extractor. - duration (float): Duration used during training. This is later used for correct padding + n_chroma (int): Number of chroma bins for the chroma extractor. + radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12). + duration (int): duration used during training. This is later used for correct padding in case we are using chroma as prefix. - match_len_on_eval (bool, optional): If True then all chromas are padded to the training + match_len_on_eval (bool, optional): if True then all chromas are padded to the training duration. Defaults to False. - eval_wavs (str, optional): Path to a json egg with waveform, this waveforms are used as + eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). Defaults to None. - n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0. + n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0. device (tp.Union[torch.device, str], optional): Device for the conditioner. **kwargs: Additional parameters for the chroma extractor. """ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None, - n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs): + n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None, + device: tp.Union[torch.device, str] = 'cpu', **kwargs): from demucs import pretrained super().__init__(dim=n_chroma, output_dim=output_dim, device=device) - self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32) + self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) self.sample_rate = sample_rate self.match_len_on_eval = match_len_on_eval + if match_len_on_eval: + self._use_masking = False self.duration = duration - self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device) - self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3} - self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device) - self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp, - device=device, **kwargs) + self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) + stem_sources: list = self.demucs.sources # type: ignore + self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device) + self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, + radix2_exp=radix2_exp, **kwargs).to(device) self.chroma_len = self._get_chroma_len() - - def _downsampling_factor(self): + self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs) + self.cache = None + if cache_path is not None: + self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_full_chroma_for_cache, + extract_embed_fn=self._extract_chroma_chunk) + + def _downsampling_factor(self) -> int: return self.chroma.winhop - def _get_chroma_len(self): - """Get length of chroma during training""" - dummy_wav = torch.zeros((1, self.sample_rate * self.duration), device=self.device) + def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]: + """Load pre-defined waveforms from a json. + These waveforms will be used for chroma extraction during evaluation. + This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps). + """ + if path is None: + return None + + logger.info(f"Loading evaluation wavs from {path}") + from audiocraft.data.audio_dataset import AudioDataset + dataset: AudioDataset = AudioDataset.from_meta( + path, segment_duration=self.duration, min_audio_duration=self.duration, + sample_rate=self.sample_rate, channels=1) + + if len(dataset) > 0: + eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device) + logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner") + return eval_wavs + else: + raise ValueError("Could not find evaluation wavs, check lengths of wavs") + + def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: + self.eval_wavs = eval_wavs + + def has_eval_wavs(self) -> bool: + return self.eval_wavs is not None + + def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: + """Sample wavs from a predefined list.""" + assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided." + total_eval_wavs = len(self.eval_wavs) + out = self.eval_wavs + if num_samples > total_eval_wavs: + out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1) + return out[torch.randperm(len(out))][:num_samples] + + def _get_chroma_len(self) -> int: + """Get length of chroma during training.""" + dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device) dummy_chr = self.chroma(dummy_wav) return dummy_chr.shape[1] @torch.no_grad() - def _get_filtered_wav(self, wav): + def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Get parts of the wav that holds the melody, extracting the main stems from the wav.""" from demucs.apply import apply_model from demucs.audio import convert_audio with self.autocast: - wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels) + wav = convert_audio( + wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore stems = apply_model(self.demucs, wav, device=self.device) - stems = stems[:, self.stem_idx] # extract stem - stems = stems.sum(1) # merge extracted stems - stems = stems.mean(1, keepdim=True) # mono - stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1) - return stems + stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning + mix_wav = stems.sum(1) # merge extracted stems to single waveform + mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore + return mix_wav @torch.no_grad() - def _get_wav_embedding(self, wav): + def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: + """Extract chroma features from the waveform.""" + with self.autocast: + return self.chroma(wav) + + @torch.no_grad() + def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Compute wav embedding, applying stem and chroma extraction.""" # avoid 0-size tensors when we are working with null conds if wav.shape[-1] == 1: - return self.chroma(wav) - stems = self._get_filtered_wav(wav) - chroma = self.chroma(stems) + return self._extract_chroma(wav) + stems = self._get_stemmed_wav(wav, sample_rate) + chroma = self._extract_chroma(stems) + return chroma + + @torch.no_grad() + def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor: + """Extract chroma from the whole audio waveform at the given path.""" + wav, sr = audio_read(path) + wav = wav[None].to(self.device) + wav = convert_audio(wav, sr, self.sample_rate, to_channels=1) + chroma = self._compute_wav_embedding(wav, self.sample_rate)[0] + return chroma + + def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor: + """Extract a chunk of chroma from the full chroma derived from the full waveform.""" + wav_length = x.wav.shape[-1] + seek_time = x.seek_time[idx] + assert seek_time is not None, ( + "WavCondition seek_time is required " + "when extracting chroma chunks from pre-computed chroma.") + full_chroma = full_chroma.float() + frame_rate = self.sample_rate / self._downsampling_factor() + target_length = int(frame_rate * wav_length / self.sample_rate) + index = int(frame_rate * seek_time) + out = full_chroma[index: index + target_length] + out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0] + return out.to(self.device) + + @torch.no_grad() + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Get the wav embedding from the WavCondition. + The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly + or will rely on the embedding cache to load the pre-computed embedding if relevant. + """ + sampled_wav: tp.Optional[torch.Tensor] = None + if not self.training and self.eval_wavs is not None: + warn_once(logger, "Using precomputed evaluation wavs!") + sampled_wav = self._sample_eval_wavs(len(x.wav)) + + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 + if sampled_wav is not None: + chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate) + elif self.cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + chroma = self.cache.get_embed_from_cache(paths, x) + else: + assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal." + chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0]) if self.match_len_on_eval: - b, t, c = chroma.shape - if t > self.chroma_len: + B, T, C = chroma.shape + if T > self.chroma_len: chroma = chroma[:, :self.chroma_len] - logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})') - elif t < self.chroma_len: - # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t)) - n_repeat = int(math.ceil(self.chroma_len / t)) + logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})") + elif T < self.chroma_len: + n_repeat = int(math.ceil(self.chroma_len / T)) chroma = chroma.repeat(1, n_repeat, 1) chroma = chroma[:, :self.chroma_len] - logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})') + logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})") + return chroma + def tokenize(self, x: WavCondition) -> WavCondition: + """Apply WavConditioner tokenization and populate cache if needed.""" + x = super().tokenize(x) + no_undefined_paths = all(p is not None for p in x.path) + if self.cache is not None and no_undefined_paths: + paths = [Path(p) for p in x.path if p is not None] + self.cache.populate_embed_cache(paths, x) + return x -class ChromaExtractor(nn.Module): - """Chroma extraction class, handles chroma extraction and quantization. + +class JointEmbeddingConditioner(BaseConditioner): + """Joint embedding conditioning supporting both audio or text conditioning. Args: - sample_rate (int): Sample rate. - n_chroma (int): Number of chroma to consider. - radix2_exp (int): Radix2 exponent. - nfft (tp.Optional[int], optional): Number of FFT. - winlen (tp.Optional[int], optional): Window length. - winhop (tp.Optional[int], optional): Window hop size. - argmax (bool, optional): Whether to use argmax. Defaults to False. - norm (float, optional): Norm for chroma normalization. Defaults to inf. - device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu. + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + autocast_dtype (str): Autocast for the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + kwargs: Additional parameters for residual vector quantizer. """ - def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, - nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, - argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"): - super().__init__() - from librosa import filters + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True, + n_q: int = 12, bins: int = 1024, **kwargs): + super().__init__(dim=dim, output_dim=output_dim) self.device = device - self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32) - self.winlen = winlen or 2 ** radix2_exp - self.nfft = nfft or self.winlen - self.winhop = winhop or (self.winlen // 4) - self.sr = sample_rate - self.n_chroma = n_chroma - self.norm = norm - self.argmax = argmax - self.window = torch.hann_window(self.winlen).to(device) - self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, - n_chroma=self.n_chroma)).to(device) - self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, - hop_length=self.winhop, power=2, center=True, - pad=0, normalized=True).to(device) - - def forward(self, wav): + self.attribute = attribute + if autocast_dtype is None or device == 'cpu': + self.autocast = TorchAutocast(enabled=False) + logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.") + else: + dtype = getattr(torch, autocast_dtype) + assert isinstance(dtype, torch.dtype) + logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.") + self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) + # residual vector quantizer to discretize the conditioned embedding + self.quantizer: tp.Optional[ResidualVectorQuantizer] = None + if quantize: + self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get joint embedding in latent space from the inputs. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding + and corresponding empty indexes. + """ + raise NotImplementedError() + + def forward(self, x: JointEmbedCondition) -> ConditionType: with self.autocast: - T = wav.shape[-1] - # in case we are getting a wav that was dropped out (nullified) - # make sure wav length is no less that nfft - if T < self.nfft: - pad = self.nfft - T - r = 0 if pad % 2 == 0 else 1 - wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) - assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}' - spec = self.spec(wav).squeeze(1) - raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec) - norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) - norm_chroma = rearrange(norm_chroma, "b d t -> b t d") - - if self.argmax: - idx = norm_chroma.argmax(-1, keepdims=True) - norm_chroma[:] = 0 - norm_chroma.scatter_(dim=-1, index=idx, value=1) - - return norm_chroma - - -def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str): + embed, empty_idx = self._get_embed(x) + if self.quantizer is not None: + embed = embed.view(-1, self.dim, 1) + q_res = self.quantizer(embed, frame_rate=1) + out_embed = q_res.x.view(-1, self.dim) + else: + out_embed = embed + out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim) + mask = torch.ones(*out_embed.shape[:2], device=out_embed.device) + mask[empty_idx, :] = 0 # zero-out index where the input is non-existant + out_embed = (out_embed * mask.unsqueeze(-1)) + return out_embed, mask + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + return x + + +class CLAPEmbeddingConditioner(JointEmbeddingConditioner): + """Joint Embedding conditioner based on pre-trained CLAP model. + + This CLAP-based conditioner supports a caching mechanism + over the computed embeddings for faster training. + + Args: + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + checkpoint (str): Path to CLAP checkpoint. + model_arch (str): CLAP model architecture. + enable_fusion (bool): Enable fusion for CLAP model. + sample_rate (int): Sample rate used by CLAP model. + max_audio_length (float): Maximum audio length for CLAP model. + audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. + normalize (bool): Whether to normalize the CLAP embedding. + text_p (float): Probability of using text representation instead of audio at train time. + batch_size (Optional[int]): Batch size for CLAP embedding computation. + autocast_dtype (str): Autocast for the conditioner. + cache_path (Optional[str]): Path for pre-computed embeddings caching. + kwargs: Additional parameters for residual vector quantizer. + """ + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, + enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, + normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None, + autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): + try: + import laion_clap # type: ignore + except ImportError: + raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") + warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). " + "Please retrain all models.") + checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) + clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') + clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) + load_clap_state_dict(clap_model, checkpoint) + clap_model.eval() + clap_model.to(device) + super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, + autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, + **kwargs) + self.checkpoint = checkpoint + self.enable_fusion = enable_fusion + self.model_arch = model_arch + self.clap: laion_clap.CLAP_Module + self.clap_tokenize: RobertaTokenizer + self.clap_sample_rate = sample_rate + self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) + self.clap_stride = int(self.clap_sample_rate * audio_stride) + self.batch_size = batch_size or 1 + self.normalize = normalize + self.text_p = text_p + self.__dict__['clap_tokenize'] = clap_tokenize + self.__dict__['clap'] = clap_model + self.wav_cache, self.text_cache = None, None + if cache_path is not None: + self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_wav_embedding_for_cache, + extract_embed_fn=self._extract_wav_embedding_chunk) + self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, + compute_embed_fn=self._get_text_embedding_for_cache) + + def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: + # we use the default params from CLAP module here as well + return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") + + def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: + """Compute text embedding from CLAP model on a given a batch of text. + + Args: + text (list[str]): List of text for the batch, with B items. + Returns: + torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension. + """ + with torch.no_grad(): + embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) + return embed.view(embed.size(0), 1, embed.size(-1)) + + def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Get text embedding function for the cache.""" + text = x.text[idx] + text = text if text is not None else "" + return self._compute_text_embedding([text])[0] + + def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor: + """Preprocess wav to expected format by CLAP model. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. + sample_rates (list[int]): Sample rates for each sample in the batch + Returns: + torch.Tensor: Audio wav of shape [B, T]. + """ + assert wav.dim() == 3, "Expecting wav to be [B, C, T]" + if sample_rates is not None: + _wav = [] + for i, audio in enumerate(wav): + sr = sample_rates[i] + audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1) + _wav.append(audio) + wav = torch.stack(_wav, dim=0) + wav = wav.mean(dim=1) + return wav + + def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, + sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor: + """Compute audio wave embedding from CLAP model. + + Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences, + we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and + average the resulting embeddings. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. + sample_rates (list[int]): Sample rates for each sample in the batch. + reduce_mean (bool): Whether to get the average tensor. + Returns: + torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension. + """ + with torch.no_grad(): + wav = self._preprocess_wav(wav, length, sample_rates) + B, T = wav.shape + if T >= self.clap_max_frames: + wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T] + else: + wav = wav.view(-1, 1, T) # [B, F, T] with F=1 + wav = einops.rearrange(wav, 'b f t -> (b f) t') + embed_list = [] + for i in range(0, wav.size(0), self.batch_size): + _wav = wav[i:i+self.batch_size, ...] + _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True) + embed_list.append(_embed) + embed = torch.cat(embed_list, dim=0) + embed = einops.rearrange(embed, '(b f) d -> b f d', b=B) + if reduce_mean: + embed = embed.mean(dim=1, keepdim=True) + return embed # [B, F, D] with F=1 if reduce_mean is True + + def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Compute audio wave embedding for the cache. + The embedding is computed on a given audio read from file. + + Args: + path (str or Path): Path to the full audio file. + Returns: + torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension. + """ + wav, sr = audio_read(path) # [C, T] + wav = wav.unsqueeze(0).to(self.device) # [1, C, T] + wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device) + embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D] + return embed.squeeze(0) # [F, D] + + def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding. + + Args: + full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D]. + x (JointEmbedCondition): Joint embedding condition for the full batch. + idx (int): Index considered for the given embedding to extract. + Returns: + torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D]. + """ + sample_rate = x.sample_rate[idx] + seek_time = x.seek_time[idx] + seek_time = 0. if seek_time is None else seek_time + clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate + end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate + start_offset = int(seek_time * sample_rate // clap_stride) + end_offset = int(end_seek_time * sample_rate // clap_stride) + wav_embed = full_embed[start_offset:end_offset, ...] + wav_embed = wav_embed.mean(dim=0, keepdim=True) + return wav_embed.to(self.device) # [F, D] + + def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: + """Get CLAP embedding from a batch of text descriptions.""" + no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout + if self.text_cache is not None and no_nullified_cond: + assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + embed = self.text_cache.get_embed_from_cache(paths, x) + else: + text = [xi if xi is not None else "" for xi in x.text] + embed = self._compute_text_embedding(text) + if self.normalize: + embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) + return embed + + def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: + """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates).""" + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout + if self.wav_cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + embed = self.wav_cache.get_embed_from_cache(paths, x) + else: + embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True) + if self.normalize: + embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) + return embed + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + # Trying to limit as much as possible sync points when the cache is warm. + no_undefined_paths = all(p is not None for p in x.path) + if self.wav_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.wav_cache.populate_embed_cache(paths, x) + if self.text_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.text_cache.populate_embed_cache(paths, x) + return x + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Extract shared latent representation from either the wav or the text using CLAP.""" + # decide whether to use text embedding at train time or not + use_text_embed = random.random() < self.text_p + if self.training and not use_text_embed: + embed = self._get_wav_embedding(x) + empty_idx = torch.LongTensor([]) # we assume we always have the audio wav + else: + embed = self._get_text_embedding(x) + empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""]) + return embed, empty_idx + + +def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: """Utility function for nullifying an attribute inside an ConditioningAttributes object. - If the condition is of type "wav", then nullify it using "nullify_condition". - If the condition is of any other type, set its' value to None. + If the condition is of type "wav", then nullify it using `nullify_condition` function. + If the condition is of any other type, set its value to None. Works in-place. """ - if condition_type not in ["text", "wav"]: + if condition_type not in ['text', 'wav', 'joint_embed']: raise ValueError( "dropout_condition got an unexpected condition type!" - f" expected 'wav' or 'text' but got '{condition_type}'" + f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" ) if condition not in getattr(sample, condition_type): raise ValueError( "dropout_condition received an unexpected condition!" f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" - f"but got '{condition}' of type '{condition_type}'!" + f" but got '{condition}' of type '{condition_type}'!" ) - if condition_type == "wav": - wav, length, path = sample.wav[condition] - sample.wav[condition] = nullify_wav(wav) + if condition_type == 'wav': + wav_cond = sample.wav[condition] + sample.wav[condition] = nullify_wav(wav_cond) + elif condition_type == 'joint_embed': + embed = sample.joint_embed[condition] + sample.joint_embed[condition] = nullify_joint_embed(embed) else: sample.text[condition] = None @@ -631,7 +1028,7 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi class DropoutModule(nn.Module): - """Base class for all dropout modules.""" + """Base module for all dropout modules.""" def __init__(self, seed: int = 1234): super().__init__() self.rng = torch.Generator() @@ -639,10 +1036,11 @@ class DropoutModule(nn.Module): class AttributeDropout(DropoutModule): - """Applies dropout with a given probability per attribute. This is different from the behavior of - ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example, - "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout - where if "artist" is dropped "genre" must also be dropped. + """Dropout with a given probability per attribute. + This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes + to be dropped out separately. For example, "artist" can be dropped while "genre" remains. + This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" + must also be dropped. Args: p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: @@ -665,21 +1063,19 @@ class AttributeDropout(DropoutModule): def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: """ Args: - samples (tp.List[ConditioningAttributes]): List of conditions. + samples (list[ConditioningAttributes]): List of conditions. Returns: - tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None. + list[ConditioningAttributes]: List of conditions after certain attributes were set to None. """ if not self.training and not self.active_on_eval: return samples samples = deepcopy(samples) - for condition_type, ps in self.p.items(): # for condition types [text, wav] for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) if torch.rand(1, generator=self.rng).item() < p: for sample in samples: dropout_condition(sample, condition_type, condition) - return samples def __repr__(self): @@ -687,8 +1083,8 @@ class AttributeDropout(DropoutModule): class ClassifierFreeGuidanceDropout(DropoutModule): - """Applies Classifier Free Guidance dropout, meaning all attributes - are dropped with the same probability. + """Classifier Free Guidance dropout. + All attributes are dropped with the same probability. Args: p (float): Probability to apply condition dropout during training. @@ -701,9 +1097,9 @@ class ClassifierFreeGuidanceDropout(DropoutModule): def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: """ Args: - samples (tp.List[ConditioningAttributes]): List of conditions. + samples (list[ConditioningAttributes]): List of conditions. Returns: - tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None. + list[ConditioningAttributes]: List of conditions after all attributes were set to None. """ if not self.training: return samples @@ -715,12 +1111,10 @@ class ClassifierFreeGuidanceDropout(DropoutModule): # nullify conditions of all attributes samples = deepcopy(samples) - for condition_type in ["wav", "text"]: for sample in samples: for condition in sample.attributes[condition_type]: dropout_condition(sample, condition_type, condition) - return samples def __repr__(self): @@ -728,29 +1122,25 @@ class ClassifierFreeGuidanceDropout(DropoutModule): class ConditioningProvider(nn.Module): - """Main class to provide conditions given all the supported conditioners. + """Prepare and provide conditions given all the supported conditioners. Args: conditioners (dict): Dictionary of conditioners. - merge_text_conditions_p (float, optional): Probability to merge all text sources - into a single text condition. Defaults to 0. - drop_desc_p (float, optional): Probability to drop the original description - when merging all text sources into a single text condition. Defaults to 0. - device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types. + device (torch.device or str, optional): Device for conditioners and output condition types. """ - def __init__( - self, - conditioners: tp.Dict[str, BaseConditioner], - merge_text_conditions_p: float = 0, - drop_desc_p: float = 0, - device: tp.Union[torch.device, str] = "cpu", - ): + def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): super().__init__() self.device = device - self.merge_text_conditions_p = merge_text_conditions_p - self.drop_desc_p = drop_desc_p self.conditioners = nn.ModuleDict(conditioners) + @property + def joint_embed_conditions(self): + return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] + + @property + def has_joint_embed_conditions(self): + return len(self.joint_embed_conditions) > 0 + @property def text_conditions(self): return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] @@ -769,33 +1159,36 @@ class ConditioningProvider(nn.Module): This will return a dict matching conditioner names to their arbitrary tokenized representations. Args: - inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing + inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing text and wav conditions. """ - assert all([type(x) == ConditioningAttributes for x in inputs]), \ - "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \ + assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( + "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", f" but types were {set([type(x) for x in inputs])}" + ) output = {} text = self._collate_text(inputs) wavs = self._collate_wavs(inputs) + joint_embeds = self._collate_joint_embeds(inputs) - assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \ - f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}" + assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( + f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ", + f"got {text.keys(), wavs.keys(), joint_embeds.keys()}" + ) - for attribute, batch in chain(text.items(), wavs.items()): + for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()): output[attribute] = self.conditioners[attribute].tokenize(batch) return output def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: - """Compute pairs of `(embedding, mask)` using the configured conditioners - and the tokenized representations. The output is for example: - - { - "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), - "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), - ... - } + """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. + The output is for example: + { + "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), + "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), + ... + } Args: tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. @@ -820,51 +1213,22 @@ class ConditioningProvider(nn.Module): "genre": ["Rock", "Hip-hop"], "description": ["A rock song with a guitar solo", "A hip-hop verse"] } - """ - batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) - - def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0): - def is_valid(k, v): - k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument'] - v_valid = v is not None and isinstance(v, (int, float, str, list)) - return k_valid and v_valid - - def process_value(v): - if isinstance(v, (int, float, str)): - return v - if isinstance(v, list): - return ", ".join(v) - else: - RuntimeError(f"unknown type for text value! ({type(v), v})") - - desc = cond.text['description'] - meta_data = "" - if random.uniform(0, 1) < merge_text_conditions_p: - meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)] - random.shuffle(meta_pairs) - meta_data = ". ".join(meta_pairs) - desc = desc if not random.uniform(0, 1) < drop_desc_p else None - - if desc is None: - desc = meta_data if len(meta_data) > 1 else None - else: - desc = desc.rstrip('.') + ". " + meta_data - cond.text['description'] = desc.strip() if desc else None - - if self.training and self.merge_text_conditions_p: - for sample in samples: - _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p) + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. + """ + out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) texts = [x.text for x in samples] for text in texts: for condition in self.text_conditions: - batch_per_attribute[condition].append(text[condition]) - - return batch_per_attribute + out[condition].append(text[condition]) + return out - def _collate_wavs(self, samples: tp.List[ConditioningAttributes]): + def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: """Generate a dict where the keys are attributes by which we fetch similar wavs, - and the values are Tensors of wavs according to said attribtues. + and the values are Tensors of wavs according to said attributes. *Note*: by the time the samples reach this function, each sample should have some waveform inside the "wav" attribute. It should be either: @@ -873,27 +1237,89 @@ class ConditioningProvider(nn.Module): 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) Args: - samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples. + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. Returns: - dict: A dicionary mapping an attribute name to wavs. + dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. """ wavs = defaultdict(list) - lens = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) paths = defaultdict(list) - out = {} + seek_times = defaultdict(list) + out: tp.Dict[str, WavCondition] = {} for sample in samples: for attribute in self.wav_conditions: - wav, length, path = sample.wav[attribute] - wavs[attribute].append(wav.flatten()) - lens[attribute].append(length) - paths[attribute].append(path) + wav, length, sample_rate, path, seek_time = sample.wav[attribute] + assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" + assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" + # mono-channel conditioning + wav = wav.mean(1, keepdim=True) # [1, 1, T] + wavs[attribute].append(wav.flatten()) # [T] + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) # stack all wavs to a single tensor for attribute in self.wav_conditions: stacked_wav, _ = collate(wavs[attribute], dim=0) - out[attribute] = WavCondition(stacked_wav.unsqueeze(1), - torch.cat(lens['self_wav']), paths[attribute]) # type: ignore + out[attribute] = WavCondition( + stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], + paths[attribute], seek_times[attribute]) + + return out + + def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]: + """Generate a dict where the keys are attributes by which we compute joint embeddings, + and the values are Tensors of pre-computed embeddings and the corresponding text attributes. + + Args: + samples (list[ConditioningAttributes]): List of ConditioningAttributes samples. + Returns: + A dictionary mapping an attribute name to joint embeddings. + """ + texts = defaultdict(list) + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + channels: int = 0 + + out = {} + for sample in samples: + for attribute in self.joint_embed_conditions: + wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] + assert wav.dim() == 3 + if channels == 0: + channels = wav.size(1) + else: + assert channels == wav.size(1), "not all audio has same number of channels in batch" + assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" + wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T] + wavs[attribute].append(wav) + texts[attribute].extend(text) + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + for attribute in self.joint_embed_conditions: + stacked_texts = texts[attribute] + stacked_paths = paths[attribute] + stacked_seek_times = seek_times[attribute] + stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) + stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) + stacked_sample_rates = sample_rates[attribute] + stacked_lengths = torch.cat(lengths[attribute]).to(self.device) + assert stacked_lengths.size(0) == stacked_wavs.size(0) + assert len(stacked_sample_rates) == stacked_wavs.size(0) + assert len(stacked_texts) == stacked_wavs.size(0) + out[attribute] = JointEmbedCondition( + text=stacked_texts, wav=stacked_wavs, + length=stacked_lengths, sample_rate=stacked_sample_rates, + path=stacked_paths, seek_time=stacked_seek_times) return out @@ -920,7 +1346,7 @@ class ConditionFuser(StreamingModule): super().__init__() assert all( [k in self.FUSING_METHODS for k in fuse2cond.keys()] - ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}" + ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" self.cross_attention_pos_emb = cross_attention_pos_emb self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond @@ -931,16 +1357,16 @@ class ConditionFuser(StreamingModule): def forward( self, - input: Tensor, + input: torch.Tensor, conditions: tp.Dict[str, ConditionType] - ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]: + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Fuse the conditions to the provided model input. Args: - input (Tensor): Transformer input. - conditions (tp.Dict[str, ConditionType]): Dict of conditions. + input (torch.Tensor): Transformer input. + conditions (dict[str, ConditionType]): Dict of conditions. Returns: - tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input + tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input after the conditions have been fused. The second output tensor is the tensor used for cross-attention or None if no cross attention inputs exist. """ @@ -959,16 +1385,16 @@ class ConditionFuser(StreamingModule): cross_attention_output = None for cond_type, (cond, cond_mask) in conditions.items(): op = self.cond2fuse[cond_type] - if op == "sum": + if op == 'sum': input += cond - elif op == "input_interpolate": - cond = rearrange(cond, "b t d -> b d t") + elif op == 'input_interpolate': + cond = einops.rearrange(cond, "b t d -> b d t") cond = F.interpolate(cond, size=input.shape[1]) - input += rearrange(cond, "b d t -> b t d") - elif op == "prepend": + input += einops.rearrange(cond, "b d t -> b t d") + elif op == 'prepend': if first_step: input = torch.cat([cond, input], dim=1) - elif op == "cross": + elif op == 'cross': if cross_attention_output is not None: cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) else: diff --git a/audiocraft/modules/conv.py b/audiocraft/modules/conv.py index 972938ab84712eb06e1b10cea25444eee51d6637..d115cbf8729b642ed78608bd00a4d0fd5afae6fd 100644 --- a/audiocraft/modules/conv.py +++ b/audiocraft/modules/conv.py @@ -46,8 +46,7 @@ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) -> int: - """See `pad_for_conv1d`. - """ + """See `pad_for_conv1d`.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) @@ -90,8 +89,7 @@ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): - """Remove padding from x, handling properly zero padding. Only for 1d! - """ + """Remove padding from x, handling properly zero padding. Only for 1d!""" padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] @@ -176,8 +174,8 @@ class StreamableConv1d(nn.Module): super().__init__() # warn user on unusual setup between dilation and stride if stride > 1 and dilation > 1: - warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1' - f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).") self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias, causal=causal, norm=norm, norm_kwargs=norm_kwargs) diff --git a/audiocraft/modules/diffusion_schedule.py b/audiocraft/modules/diffusion_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..74ca6e3f2e7c4ff904d96dade315b0b46856778d --- /dev/null +++ b/audiocraft/modules/diffusion_schedule.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for Noise Schedule, defines diffusion process, reverse process and data processor. +""" + +from collections import namedtuple +import random +import typing as tp +import julius +import torch + +TrainingItem = namedtuple("TrainingItem", "noisy noise step") + + +def betas_from_alpha_bar(alpha_bar): + alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]]) + return 1 - alphas + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + return x + + def return_sample(self, z: torch.Tensor): + """Project back from diffusion space to the actual sample space.""" + return z + + +class MultiBandProcessor(SampleProcessor): + """ + MultiBand sample processor. The input audio is splitted across + frequency bands evenly distributed in mel-scale. + + Each band will be rescaled to match the power distribution + of Gaussian noise in that band, using online metrics + computed on the first few samples. + + Args: + n_bands (int): Number of mel-bands to split the signal over. + sample_rate (int): Sample rate of the audio. + num_samples (int): Number of samples to use to fit the rescaling + for each band. The processor won't be stable + until it has seen that many samples. + power_std (float or list/tensor): The rescaling factor computed to match the + power of Gaussian noise in each band is taken to + that power, i.e. `1.` means full correction of the energy + in each band, and values less than `1` means only partial + correction. Can be used to balance the relative importance + of low vs. high freq in typical audio signals. + """ + def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, + num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.): + super().__init__() + self.n_bands = n_bands + self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands) + self.num_samples = num_samples + self.power_std = power_std + if isinstance(power_std, list): + assert len(power_std) == n_bands + power_std = torch.tensor(power_std) + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(n_bands)) + self.register_buffer('sum_x2', torch.zeros(n_bands)) + self.register_buffer('sum_target_x2', torch.zeros(n_bands)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + self.sum_target_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + return std + + @property + def target_std(self): + target_std = self.sum_target_x2 / self.counts + return target_std + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + bands = self.split_bands(x) + if self.counts.item() < self.num_samples: + ref_bands = self.split_bands(torch.randn_like(x)) + self.counts += len(x) + self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1) + self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1) + self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1) + return bands.sum(dim=0) + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + bands = self.split_bands(x) + rescale = (self.std / self.target_std) ** self.power_std + bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1) + return bands.sum(dim=0) + + +class NoiseSchedule: + """Noise schedule for diffusion. + + Args: + beta_t0 (float): Variance of the first diffusion step. + beta_t1 (float): Variance of the last diffusion step. + beta_exp (float): Power schedule exponent + num_steps (int): Number of diffusion step. + variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde" + clip (float): clipping value for the denoising steps + rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1) + repartition (str): shape of the schedule only power schedule is supported + sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution + noise_scale (float): Scaling factor for the noise + """ + def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta', + clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1, + repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None, + sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs): + + self.beta_t0 = beta_t0 + self.beta_t1 = beta_t1 + self.variance = variance + self.num_steps = num_steps + self.clip = clip + self.sample_processor = sample_processor + self.rescale = rescale + self.n_bands = n_bands + self.noise_scale = noise_scale + assert n_bands is None + if repartition == "power": + self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps, + device=device, dtype=torch.float) ** beta_exp + else: + raise RuntimeError('Not implemented') + self.rng = random.Random(1234) + + def get_beta(self, step: tp.Union[int, torch.Tensor]): + if self.n_bands is None: + return self.betas[step] + else: + return self.betas[:, step] # [n_bands, len(step)] + + def get_initial_noise(self, x: torch.Tensor): + if self.n_bands is None: + return torch.randn_like(x) + return torch.randn((x.size(0), self.n_bands, x.size(2))) + + def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor: + """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.""" + if step is None: + return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands + if type(step) is int: + return (1 - self.betas[:step + 1]).prod() + else: + return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1) + + def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem: + """Create a noisy data item for diffusion model training: + + Args: + x (torch.Tensor): clean audio data torch.tensor(bs, 1, T) + tensor_step (bool): If tensor_step = false, only one step t is sample, + the whole batch is diffused to the same step and t is int. + If tensor_step = true, t is a tensor of size (x.size(0),) + every element of the batch is diffused to a independently sampled. + """ + step: tp.Union[int, torch.Tensor] + if tensor_step: + bs = x.size(0) + step = torch.randint(0, self.num_steps, size=(bs,), device=x.device) + else: + step = self.rng.randrange(self.num_steps) + alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1] + + x = self.sample_processor.project_sample(x) + noise = torch.randn_like(x) + noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale + return TrainingItem(noisy, noise, step) + + def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None, + condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): + """Full ddpm reverse process. + + Args: + model (nn.Module): Diffusion model. + initial (tensor): Initial Noise. + condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation). + return_list (bool): Whether to return the whole process or only the sampled point. + """ + alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) + current = initial + iterates = [initial] + for step in range(self.num_steps)[::-1]: + with torch.no_grad(): + estimate = model(current, step, condition=condition).sample + alpha = 1 - self.betas[step] + previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() + previous_alpha_bar = self.get_alpha_bar(step=step - 1) + if step == 0: + sigma2 = 0 + elif self.variance == 'beta': + sigma2 = 1 - alpha + elif self.variance == 'beta_tilde': + sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) + elif self.variance == 'none': + sigma2 = 0 + else: + raise ValueError(f'Invalid variance type {self.variance}') + + if sigma2 > 0: + previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale + if self.clip: + previous = previous.clamp(-self.clip, self.clip) + current = previous + alpha_bar = previous_alpha_bar + if step == 0: + previous *= self.rescale + if return_list: + iterates.append(previous.cpu()) + + if return_list: + return iterates + else: + return self.sample_processor.return_sample(previous) + + def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None, + condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): + """Reverse process that only goes through Markov chain states in step_list.""" + if step_list is None: + step_list = list(range(1000))[::-50] + [0] + alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) + alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu() + betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled) + current = initial * self.noise_scale + iterates = [current] + for idx, step in enumerate(step_list[:-1]): + with torch.no_grad(): + estimate = model(current, step, condition=condition).sample * self.noise_scale + alpha = 1 - betas_subsampled[-1 - idx] + previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() + previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1]) + if step == step_list[-2]: + sigma2 = 0 + previous_alpha_bar = torch.tensor(1.0) + else: + sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) + if sigma2 > 0: + previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale + if self.clip: + previous = previous.clamp(-self.clip, self.clip) + current = previous + alpha_bar = previous_alpha_bar + if step == 0: + previous *= self.rescale + if return_list: + iterates.append(previous.cpu()) + if return_list: + return iterates + else: + return self.sample_processor.return_sample(previous) diff --git a/audiocraft/modules/rope.py b/audiocraft/modules/rope.py index 4b8c70b9aba28eeb53d12ddc3de8852492847808..c12cee0954f27c45d79627771fdf7fa9fc10dfcc 100644 --- a/audiocraft/modules/rope.py +++ b/audiocraft/modules/rope.py @@ -18,7 +18,7 @@ class XPos(nn.Module): dim (int): Embedding dimension. smoothing (float): Smoothing factor applied to the decay rates. base_scale (int): Base decay rate, given in terms of scaling time. - device (torch.device or None): Device on which to initialize the module. + device (torch.device, optional): Device on which to initialize the module. dtype (torch.dtype): dtype to use to generate the embedding. """ def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512, @@ -36,8 +36,7 @@ class XPos(nn.Module): self.decay: tp.Optional[torch.Tensor] = None def get_decay(self, start: int, end: int): - """Create complex decay tensor, cache values for fast computation. - """ + """Create complex decay tensor, cache values for fast computation.""" if self.decay is None or end > self.decay.shape[0]: assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker. idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype) @@ -55,7 +54,7 @@ class RotaryEmbedding(nn.Module): max_period (float): Maximum period of the rotation frequencies. xpos (bool): Use xPos, applies an exponential decay to rotation matrix. scale (float): Scale of positional embedding, set to 0 to deactivate. - device (torch.device or None): Device on which to initialize the module. + device (torch.device, optional): Device on which to initialize the module. dtype (torch.dtype): dtype to use to generate the embedding. """ def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False, @@ -74,8 +73,7 @@ class RotaryEmbedding(nn.Module): self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None def get_rotation(self, start: int, end: int): - """Create complex rotation tensor, cache values for fast computation. - """ + """Create complex rotation tensor, cache values for fast computation.""" if self.rotation is None or end > self.rotation.shape[0]: assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker. idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype) @@ -83,14 +81,16 @@ class RotaryEmbedding(nn.Module): self.rotation = torch.polar(torch.ones_like(angles), angles) return self.rotation[start:end] - def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False): - """Apply rope rotation to query or key tensor. - """ - T = x.shape[1] - rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2) + def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False): + """Apply rope rotation to query or key tensor.""" + T = x.shape[time_dim] + target_shape = [1] * x.dim() + target_shape[time_dim] = T + target_shape[-1] = -1 + rotation = self.get_rotation(start, start + T).view(target_shape) if self.xpos: - decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2) + decay = self.xpos.get_decay(start, start + T).view(target_shape) else: decay = 1.0 @@ -99,26 +99,27 @@ class RotaryEmbedding(nn.Module): x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)) scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale) - x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2) + x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x) return x_out.type_as(x) - def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0): + def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1): """ Apply rope rotation to both query and key tensors. Supports streaming mode, in which query and key are not expected to have the same shape. - In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but + In streaming mode, key will be of length [P + C] with P the cached past timesteps, but query will be [C] (typically C == 1). Args: query (torch.Tensor): Query to rotate. key (torch.Tensor): Key to rotate. start (int): Start index of the sequence for time offset. + time_dim (int): which dimension represent the time steps. """ - query_timesteps = query.shape[1] - key_timesteps = key.shape[1] + query_timesteps = query.shape[time_dim] + key_timesteps = key.shape[time_dim] streaming_offset = key_timesteps - query_timesteps - query_out = self.rotate(query, start + streaming_offset) - key_out = self.rotate(key, start, invert_decay=True) + query_out = self.rotate(query, start + streaming_offset, time_dim) + key_out = self.rotate(key, start, time_dim, invert_decay=True) return query_out, key_out diff --git a/audiocraft/modules/streaming.py b/audiocraft/modules/streaming.py index fdbdf5e90fc0c6560873d66bf273460b38e5ed7e..fba06936294ca15d72acd2d44f9dbda39a638107 100644 --- a/audiocraft/modules/streaming.py +++ b/audiocraft/modules/streaming.py @@ -57,8 +57,7 @@ class StreamingModule(nn.Module): @contextmanager def streaming(self): - """Context manager to enter streaming mode. Reset streaming state on exit. - """ + """Context manager to enter streaming mode. Reset streaming state on exit.""" self._set_streaming(True) try: yield @@ -67,16 +66,14 @@ class StreamingModule(nn.Module): self.reset_streaming() def reset_streaming(self): - """Reset the streaming state. - """ + """Reset the streaming state.""" def _reset(name: str, module: StreamingModule): module._streaming_state.clear() self._apply_named_streaming(_reset) def get_streaming_state(self) -> State: - """Return the streaming state, including that of sub-modules. - """ + """Return the streaming state, including that of sub-modules.""" state: State = {} def _add(name: str, module: StreamingModule): @@ -89,8 +86,7 @@ class StreamingModule(nn.Module): return state def set_streaming_state(self, state: State): - """Set the streaming state, including that of sub-modules. - """ + """Set the streaming state, including that of sub-modules.""" state = dict(state) def _set(name: str, module: StreamingModule): diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py index e69cca829d774d0b8b36c0de9b7924373da81b43..691df6a21657ea00f5f3ab0ed6f1cfea444dc746 100644 --- a/audiocraft/modules/transformer.py +++ b/audiocraft/modules/transformer.py @@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'): _efficient_attention_backend = backend -def _get_attention_time_dimension() -> int: - if _efficient_attention_backend == 'torch': +def _get_attention_time_dimension(memory_efficient: bool) -> int: + if _efficient_attention_backend == 'torch' and memory_efficient: return 2 else: return 1 @@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) -def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers""" +def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.""" if n_rep == 1: return x - if _efficient_attention_backend == 'torch': + if _efficient_attention_backend == 'torch' and memory_efficient: bs, n_kv_heads, slen, head_dim = x.shape return ( x[:, :, None, :, :] @@ -111,14 +111,14 @@ def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class LayerScale(nn.Module): """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). - This rescales diagonaly the residual outputs close to 0, with a learnt scale. + This rescales diagonally the residual outputs close to 0, with a learnt scale. Args: channels (int): Number of channels. init (float): Initial scale. channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. - device (torch.device or None): Device on which to initialize the module. - dtype (torch.dtype or None): dtype to use to initialize the module. + device (torch.device or str, optional): Device on which to initialize the module. + dtype (torch.dtype, optional): dtype to use to initialize the module. """ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True, device=None, dtype=None): @@ -144,22 +144,22 @@ class StreamingMultiheadAttention(StreamingModule): dropout (float): Dropout level. bias (bool): Use bias in projections. causal (bool): Causal mask applied automatically. - past_context (int or None): Receptive field for the causal mask, infinite if None. + past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 (especially important with memory_efficient as autocast won't do this automatically). - rope (`RotaryEmbedding` or None): Rope embedding to use. + rope (`RotaryEmbedding`, optional): Rope embedding to use. cross_attention: Should be true when used as a cross attention. All keys and values must be available at once, streaming is only for the queries. Cannot be used with `causal` or `rope` (as it wouldn't make sens to - intepret the time steps in the keys relative to those in the queries). + interpret the time steps in the keys relative to those in the queries). safe_streaming (bool): Bug fix, will go away with xformers update. qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product. kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). This will lead to faster decoding time on A100 or other GPUs with tensorcore. - device (torch.device or None): Sevice on which to initialize. - dtype (torch.dtype or None): dtype to use. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. """ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, @@ -234,14 +234,14 @@ class StreamingMultiheadAttention(StreamingModule): # Return a causal mask, accounting for potentially stored past keys/values # We actually return a bias for the attention score, as this has the same # convention both in the builtin MHA in Pytorch, and Xformers functions. - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if self.memory_efficient: from xformers.ops import LowerTriangularMask if current_steps == 1: # If we only have one step, then we do not need a mask. return None elif 'past_keys' in self._streaming_state: - raise RuntimeError('Not supported at the moment') + raise RuntimeError("Not supported at the moment") else: # Then we can safely use a lower triangular mask return LowerTriangularMask() @@ -264,7 +264,7 @@ class StreamingMultiheadAttention(StreamingModule): torch.full([], float('-inf'), device=device, dtype=dtype)) def _complete_kv(self, k, v): - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if self.cross_attention: # With cross attention we assume all keys and values # are already available, and streaming is with respect @@ -298,8 +298,7 @@ class StreamingMultiheadAttention(StreamingModule): return nk, nv def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): - # TODO: fix and verify layout. - assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.' + time_dim = _get_attention_time_dimension(self.memory_efficient) # Apply rope embeddings to query and key tensors. assert self.rope is not None if 'past_keys' in self._streaming_state: @@ -311,16 +310,16 @@ class StreamingMultiheadAttention(StreamingModule): else: past_context_offset = 0 streaming_offset = past_context_offset + past_keys_offset - return self.rope.rotate_qk(query, key, start=streaming_offset) + return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim) def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False): assert attn_mask is None - assert not is_causal, ("new param added in torch 2.0.1 not supported, " + assert not is_causal, ("New param added in torch 2.0.1 not supported, " "use the causal args in the constructor.") - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if time_dim == 2: layout = "b h t d" else: @@ -394,8 +393,8 @@ class StreamingMultiheadAttention(StreamingModule): q, k = self._apply_rope(q, k) k, v = self._complete_kv(k, v) if self.kv_repeat > 1: - k = expand_repeated_kv(k, self.kv_repeat) - v = expand_repeated_kv(v, self.kv_repeat) + k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient) + v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient) if self.attention_as_float32: q, k, v = [x.float() for x in [q, k, v]] if self.memory_efficient: @@ -455,7 +454,7 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer): bias_ff (bool): Use bias for FF. bias_attn (bool): Use bias for MHA. causal (bool): Causal mask applied automatically. - past_context (int or None): Receptive field for the causal mask, infinite if None. + past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 @@ -465,15 +464,15 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer): cross_attention (bool): If True, expect to get secondary input for cross-attention. Cross attention will use the default MHA, as it typically won't require special treatment. - layer_scale (float or None): If not None, LayerScale will be used with + layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale. - rope (`RotaryEmbedding` or None): Rope embedding to use. - attention_dropout (float or None): If not None, separate the value of the dimension dropout + rope (`RotaryEmbedding`, optional): Rope embedding to use. + attention_dropout (float, optional): If not None, separate the value of the dimension dropout in FFN and of the attention dropout. kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). This will lead to faster decoding time on A100 or other GPUs with tensorcore. - device (torch.device or None): Device on which to initialize. - dtype (torch.dtype or None): dtype to use. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. **kwargs: See `nn.TransformerEncoderLayer`. """ def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, @@ -576,30 +575,30 @@ class StreamingTransformer(StreamingModule): bias_ff (bool): Use bias for FF. bias_attn (bool): Use bias for MHA. causal (bool): Causal mask applied automatically. - past_context (int or None): Receptive field for the causal mask, infinite if None. + past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 (especially important with memory_efficient as autocast won't do this automatically). cross_attention (bool): If True, expect to get secondary input for cross-attention. - layer_scale (float or None): If not None, LayerScale will be used + layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale. positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope). max_period (float): Maximum period of the time embedding. positional_scale (float): Scale of positional embedding, set to 0 to deactivate. xpos (bool): Apply xpos exponential decay to positional embedding (rope only). - lr (float or None): learning rate override through the `make_optim_group` API. - weight_decay (float or None): Weight_decay override through the `make_optim_group` API. + lr (float, optional): learning rate override through the `make_optim_group` API. + weight_decay (float, optional): Weight_decay override through the `make_optim_group` API. layer_class: (subclass of `StreamingTransformerLayer): class to use - to initialize the layers, allowing further customization outside of Audiocraft. + to initialize the layers, allowing further customization outside of AudioCraft. checkpointing (str): Checkpointing strategy to reduce memory usage. No checkpointing if set to 'none'. Per layer checkpointing using PyTorch if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, minimal memory usage, but maximal runtime). Finally, `xformers_default` provide a policy for opting-out some operations of the checkpointing like linear layers and attention, providing a middle ground between speed and memory. - device (torch.device or None): Device on which to initialize. - dtype (torch.dtype or None): dtype to use. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. **kwargs: See `nn.TransformerEncoderLayer`. """ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, @@ -713,7 +712,7 @@ class StreamingTransformer(StreamingModule): return group -# special attention attention related function +# special attention related function def _verify_xformers_memory_efficient_compat(): try: diff --git a/audiocraft/optim/__init__.py b/audiocraft/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f48c17dfafa9a2be46a91ed1fb64f54c5572a730 --- /dev/null +++ b/audiocraft/optim/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers +and Exponential Moving Average. +""" + +# flake8: noqa +from .cosine_lr_scheduler import CosineLRScheduler +from .dadam import DAdaptAdam +from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler +from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler +from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler +from .ema import ModuleDictEMA diff --git a/audiocraft/optim/cosine_lr_scheduler.py b/audiocraft/optim/cosine_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4f0bbf28f1ad893a301f1bfac1da8e97370337 --- /dev/null +++ b/audiocraft/optim/cosine_lr_scheduler.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class CosineLRScheduler(_LRScheduler): + """Cosine LR scheduler. + + Args: + optimizer (Optimizer): Torch optimizer. + warmup_steps (int): Number of warmup steps. + total_steps (int): Total number of steps. + lr_min_ratio (float): Minimum learning rate. + cycle_length (float): Cycle length. + """ + def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, + lr_min_ratio: float = 0.0, cycle_length: float = 1.0): + self.warmup_steps = warmup_steps + assert self.warmup_steps >= 0 + self.total_steps = total_steps + assert self.total_steps >= 0 + self.lr_min_ratio = lr_min_ratio + self.cycle_length = cycle_length + super().__init__(optimizer) + + def _get_sched_lr(self, lr: float, step: int): + if step < self.warmup_steps: + lr_ratio = step / self.warmup_steps + lr = lr_ratio * lr + elif step <= self.total_steps: + s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) + lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \ + (1. + math.cos(math.pi * s / self.cycle_length)) + lr = lr_ratio * lr + else: + lr_ratio = self.lr_min_ratio + lr = lr_ratio * lr + return lr + + def get_lr(self): + return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs] diff --git a/audiocraft/optim/dadam.py b/audiocraft/optim/dadam.py new file mode 100644 index 0000000000000000000000000000000000000000..e009969f2ba405d621f9dd6cf0fa2c0d4a428f51 --- /dev/null +++ b/audiocraft/optim/dadam.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any + +import torch +import torch.optim +import torch.distributed as dist + + +logger = logging.getLogger(__name__) +_params_t = Any + + +def to_real(x): + if torch.is_complex(x): + return x.real + else: + return x + + +class DAdaptAdam(torch.optim.Optimizer): + """Adam with D-Adaptation automatic step-sizes. + Leave LR set to 1 unless you encounter instability. + + Args: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. + betas (tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + momentum (float): + Momentum value in the range [0,1) (default: 0.9). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + log_every (int): + Log using print every k steps, default 0 (no logging). + decouple (boolean): + Use AdamW style decoupled weight decay + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. + """ + def __init__(self, params, lr=1.0, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + log_every=0, + decouple=True, + d0=1e-6, + growth_rate=float('inf')): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple: + logger.info("Using decoupled weight decay") + + from .fsdp import is_fsdp_used + fsdp_in_use = is_fsdp_used() + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + d=d0, + k=0, + gsq_weighted=0.0, + log_every=log_every, + decouple=decouple, + growth_rate=growth_rate, + fsdp_in_use=fsdp_in_use) + + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + g_sq = 0.0 + sksq_weighted = 0.0 + sk_l1 = 0.0 + + lr = max(group['lr'] for group in self.param_groups) + + group = self.param_groups[0] + gsq_weighted = group['gsq_weighted'] + d = group['d'] + dlr = d*lr + + growth_rate = group['growth_rate'] + decouple = group['decouple'] + fsdp_in_use = group['fsdp_in_use'] + log_every = group['log_every'] + + beta1, beta2 = group['betas'] + + for group in self.param_groups: + group_lr = group['lr'] + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + + if group_lr not in [lr, 0.0]: + raise RuntimeError("Setting different lr values in different parameter " + "groups is only supported for values of 0") + + for p in group['params']: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + grad = p.grad.data + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + + state = self.state[p] + + # State initialization + if 'step' not in state: + state['step'] = 0 + state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + to_real(p.data), memory_format=torch.preserve_format).detach() + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + grad_grad = to_real(grad * grad.conj()) + + # Adam EMA updates + if group_lr > 0: + exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1)) + exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2) + + denom = exp_avg_sq.sqrt().add_(eps) + + g_sq += grad_grad.div_(denom).sum().item() + + s = state['s'] + s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2)) + sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item() + sk_l1 += s.abs().sum().item() + + ###### + + gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2) + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0) + if sk_l1 == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(3, device='cuda') + dist_tensor[0] = sksq_weighted + dist_tensor[1] = gsq_weighted + dist_tensor[2] = sk_l1 + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_sksq_weighted = dist_tensor[0] + global_gsq_weighted = dist_tensor[1] + global_sk_l1 = dist_tensor[2] + else: + global_sksq_weighted = sksq_weighted + global_gsq_weighted = gsq_weighted + global_sk_l1 = sk_l1 + + d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1 + d = max(d, min(d_hat, d*growth_rate)) + + if log_every > 0 and k % log_every == 0: + logger.info( + f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. " + f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} " + f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}") + + for group in self.param_groups: + group['gsq_weighted'] = gsq_weighted + group['d'] = d + + group_lr = group['lr'] + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + + denom = exp_avg_sq.sqrt().add_(eps) + denom = denom.type(p.type()) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple and group_lr > 0: + p.data.add_(p.data, alpha=-decay * dlr) + + # Take step + p.data.addcdiv_(exp_avg, denom, value=-1) + + group['k'] = k + 1 + + return loss diff --git a/audiocraft/optim/ema.py b/audiocraft/optim/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..4337eaff066a8ca124dca3e3e63ee36e417c055c --- /dev/null +++ b/audiocraft/optim/ema.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# ModelEMA implementation is taken from +# https://github.com/facebookresearch/demucs + +from collections import defaultdict +import typing as tp + +import torch +import torch.nn as nn + + +def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set: + names: set = set() + for (name, sub_module) in module.named_modules(): + if name == '': + buffer_names = module._non_persistent_buffers_set + buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name + for buff_name in buffer_names} + names.update(buffer_names) + else: + sub_name = f"{root}.{name}" if len(root) > 0 else name + sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name) + names.update(sub_buffer_names) + return names + + +def _get_named_tensors(module: nn.Module): + non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module) + named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers() + if name not in non_persistent_buffers_set] + named_parameters = list(module.named_parameters()) + return named_parameters + named_buffers + + +class ModuleDictEMA: + """Exponential Moving Average over a nn.ModuleDict. + + You can switch to the EMA weights temporarily. + """ + def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, + unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'): + self.decay = decay + self.module_dict = module_dict + self.state: dict = defaultdict(dict) + self.count = 0 + self.device = device + self.unbias = unbias + self._init() + + def _init(self): + for module_name, module in self.module_dict.items(): + for key, val in _get_named_tensors(module): + if not val.is_floating_point(): + continue + device = self.device or val.device + if key not in self.state[module_name]: + self.state[module_name][key] = val.detach().to(device, copy=True) + + def step(self): + if self.unbias: + self.count = self.count * self.decay + 1 + w = 1 / self.count + else: + w = 1 - self.decay + for module_name, module in self.module_dict.items(): + for key, val in _get_named_tensors(module): + if not val.is_floating_point(): + continue + device = self.device or val.device + self.state[module_name][key].mul_(1 - w) + self.state[module_name][key].add_(val.detach().to(device), alpha=w) + + def state_dict(self): + return {'state': self.state, 'count': self.count} + + def load_state_dict(self, state): + self.count = state['count'] + for module_name, module in state['state'].items(): + for key, val in module.items(): + self.state[module_name][key].copy_(val) diff --git a/audiocraft/optim/fsdp.py b/audiocraft/optim/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..1090d3d7e60065802f43330512c71cda16d0406a --- /dev/null +++ b/audiocraft/optim/fsdp.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Wrapper around FSDP for more convenient use in the training loops. +""" + +from contextlib import contextmanager +import typing as tp +import dora +import torch + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ( + MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType) +from torch.distributed._shard.sharded_tensor.api import ShardedTensor + + +def is_fsdp_used() -> bool: + """Return whether we are using FSDP.""" + # A bit of a hack but should work from anywhere. + if dora.is_xp(): + cfg = dora.get_xp().cfg + if hasattr(cfg, 'fsdp'): + return cfg.fsdp.use + return False + + +def is_sharded_tensor(x: tp.Any) -> bool: + return isinstance(x, ShardedTensor) + + +@contextmanager +def switch_to_full_state_dict(models: tp.List[FSDP]): + # Another bug in FSDP makes it that we cannot use the `state_dict_type` API, + # so let's do thing manually. + for model in models: + FSDP.set_state_dict_type( # type: ignore + model, StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True)) + try: + yield + finally: + for model in models: + FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT) # type: ignore + + +def wrap_with_fsdp(cfg, model: torch.nn.Module, + block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP: + """Wraps a model with FSDP.""" + # Some of the typing is disabled until this gets integrated + # into the stable version of PyTorch. + from torch.distributed.fsdp.wrap import ModuleWrapPolicy # type: ignore + + # we import this here to prevent circular import. + from ..modules.transformer import StreamingTransformerLayer + from ..modules.conditioners import ConditioningProvider + + _fix_post_backward_hook() + + assert cfg.use + sharding_strategy_dict = { + "no_shard": ShardingStrategy.NO_SHARD, + "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP, + "full_shard": ShardingStrategy.FULL_SHARD, + } + + dtype_dict = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + + mixed_precision_config = MixedPrecision( + param_dtype=dtype_dict[cfg.param_dtype], + reduce_dtype=dtype_dict[cfg.reduce_dtype], + buffer_dtype=dtype_dict[cfg.buffer_dtype], + ) + + sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy] + # The following is going to require being a bit smart + # when doing LM, because this would flush the weights for every time step + # during generation. One possiblity is to use hybrid sharding: + # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy + assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \ + "Not supported at the moment, requires a bit more work." + + local_rank = dora.distrib.get_distrib_spec().local_rank + assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!" + + auto_wrap_policy = None + if block_classes is None: + block_classes = {StreamingTransformerLayer, ConditioningProvider} + if cfg.per_block: + auto_wrap_policy = ModuleWrapPolicy(block_classes) + wrapped = _FSDPFixStateDict( + model, + sharding_strategy=sharding_strategy_config, + mixed_precision=mixed_precision_config, + device_id=local_rank, + sync_module_states=True, + use_orig_params=True, + auto_wrap_policy=auto_wrap_policy, + ) # type: ignore + FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT) # type: ignore + + # Let the wrapped model know about the wrapping! + # We use __dict__ to avoid it going into the state dict. + # This is a bit dirty, but needed during generation, as otherwise + # the wrapped model would call itself and bypass FSDP. + for module in FSDP.fsdp_modules(wrapped): + original = module._fsdp_wrapped_module + original.__dict__['_fsdp'] = module + return wrapped + + +def purge_fsdp(model: FSDP): + """Purge the FSDP cached shard inside the model. This should + allow setting the best state or switching to the EMA. + """ + from torch.distributed.fsdp._runtime_utils import _reshard # type: ignore + for module in FSDP.fsdp_modules(model): + handles = module._handles + if not handles: + continue + handle = handles[0] + unsharded_flat_param = handle._get_padded_unsharded_flat_param() + storage_size: int = unsharded_flat_param._typed_storage()._size() # type: ignore + if storage_size == 0: + continue + true_list = [True for h in handles] + _reshard(module, handles, true_list) + + +class _FSDPFixStateDict(FSDP): + @staticmethod + def _name_without_fsdp_prefix(name: str) -> str: + from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE # type: ignore + parts = name.split('.') + new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE] + return '.'.join(new_parts) + + def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: # type: ignore + state = dict(super().state_dict(*args, **kwargs)) + for key, value in list(state.items()): + if is_sharded_tensor(value): + del state[key] + return state + + def load_state_dict(self, state: tp.Dict[str, tp.Any]): # type: ignore + if self._state_dict_type is StateDictType.FULL_STATE_DICT: + super().load_state_dict(state) + purge_fsdp(self) + return + # Fix FSDP load state dict in all situation. + # Use this only with LOCAL_STATE_DICT !!! + current_state = dict(super().state_dict()) + for key, value in state.items(): + key = _FSDPFixStateDict._name_without_fsdp_prefix(key) + if key not in current_state: + # Emulate strict loading manually. + raise RuntimeError(f"Unknown state key {key}") + current_state[key].copy_(value) + + # Purging cached weights from previous forward. + purge_fsdp(self) + + +_hook_fixed = False + + +def _fix_post_backward_hook(): + global _hook_fixed + if _hook_fixed: + return + _hook_fixed = True + + from torch.distributed.fsdp import _runtime_utils + from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState + old_hook = _runtime_utils._post_backward_hook + + def _post_backward_hook(state, handle, *args, **kwargs): + checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False) + if checkpointed: + # there will be one more forward in the backward with checkpointing and that will + # massively confuse FSDP, so we have to make it think everything + # is going according to the plan. + state.training_state = TrainingState.FORWARD_BACKWARD + handle._training_state = HandleTrainingState.BACKWARD_PRE + old_hook(state, handle, *args, **kwargs) + + _runtime_utils._post_backward_hook = _post_backward_hook diff --git a/audiocraft/optim/inverse_sqrt_lr_scheduler.py b/audiocraft/optim/inverse_sqrt_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..920192e8842c5635bf6f7f76618fa4a6f4b0114a --- /dev/null +++ b/audiocraft/optim/inverse_sqrt_lr_scheduler.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class InverseSquareRootLRScheduler(_LRScheduler): + """Inverse square root LR scheduler. + + Args: + optimizer (Optimizer): Torch optimizer. + warmup_steps (int): Number of warmup steps. + warmup_init_lr (tp.Optional[float]): Initial learning rate + during warmup phase. When not set, use the provided learning rate. + """ + def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): + self.warmup_steps = warmup_steps + self.warmup_init_lr = warmup_init_lr + super().__init__(optimizer) + + def _get_sched_lr(self, lr: float, step: int): + if step < self.warmup_steps: + warmup_init_lr = self.warmup_init_lr or 0 + lr_step = (lr - warmup_init_lr) / self.warmup_steps + lr = warmup_init_lr + step * lr_step + else: + decay_factor = lr * self.warmup_steps**0.5 + lr = decay_factor * step**-0.5 + return lr + + def get_lr(self): + return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs] diff --git a/audiocraft/optim/linear_warmup_lr_scheduler.py b/audiocraft/optim/linear_warmup_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..03274a1ae52b6f20473973b77619f34b2bddd6a1 --- /dev/null +++ b/audiocraft/optim/linear_warmup_lr_scheduler.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class LinearWarmupLRScheduler(_LRScheduler): + """Inverse square root LR scheduler. + + Args: + optimizer (Optimizer): Torch optimizer. + warmup_steps (int): Number of warmup steps. + warmup_init_lr (tp.Optional[float]): Initial learning rate + during warmup phase. When not set, use the provided learning rate. + """ + def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): + self.warmup_steps = warmup_steps + self.warmup_init_lr = warmup_init_lr + super().__init__(optimizer) + + def _get_sched_lr(self, lr: float, step: int): + if step < self.warmup_steps: + warmup_init_lr = self.warmup_init_lr or 0 + lr_step = (lr - warmup_init_lr) / self.warmup_steps + lr = warmup_init_lr + step * lr_step + return lr + + def get_lr(self): + return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] diff --git a/audiocraft/optim/polynomial_decay_lr_scheduler.py b/audiocraft/optim/polynomial_decay_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ea30b094538269dbb0055ab3163f84d1cf6e90 --- /dev/null +++ b/audiocraft/optim/polynomial_decay_lr_scheduler.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class PolynomialDecayLRScheduler(_LRScheduler): + """Polynomial decay LR scheduler. + + Args: + optimizer (Optimizer): Torch optimizer. + warmup_steps (int): Number of warmup steps. + total_steps (int): Total number of steps. + end_lr (float): Final learning rate to achieve over total number of steps. + zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0. + power (float): Decay exponent. + """ + def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int, + end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.end_lr = end_lr + self.zero_lr_warmup_steps = zero_lr_warmup_steps + self.power = power + super().__init__(optimizer) + + def _get_sched_lr(self, lr: float, step: int): + if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps: + lr = 0 + elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps: + lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps) + lr = lr_ratio * lr + elif step >= self.total_steps: + lr = self.end_lr + else: + total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps + lr_range = lr - self.end_lr + pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps) + lr = lr_range * pct_remaining ** self.power + self.end_lr + return lr + + def get_lr(self): + return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] diff --git a/audiocraft/quantization/__init__.py b/audiocraft/quantization/__init__.py index 836d6eb518978480c6b95d6f29ce4f84a9428793..1e0c7e429ab96d67be667e23bf7a0ffa389c036b 100644 --- a/audiocraft/quantization/__init__.py +++ b/audiocraft/quantization/__init__.py @@ -3,7 +3,7 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - +"""RVQ.""" # flake8: noqa from .vq import ResidualVectorQuantizer from .base import BaseQuantizer, DummyQuantizer, QuantizedResult diff --git a/audiocraft/quantization/base.py b/audiocraft/quantization/base.py index 1b16c130d266fbd021d3fc29bb9f98c33dd3c588..a77fefb98e62a5bbc6385910261ffdde2ffa5a25 100644 --- a/audiocraft/quantization/base.py +++ b/audiocraft/quantization/base.py @@ -38,30 +38,25 @@ class BaseQuantizer(nn.Module): raise NotImplementedError() def encode(self, x: torch.Tensor) -> torch.Tensor: - """Encode a given input tensor with the specified sample rate at the given bandwidth. - """ + """Encode a given input tensor with the specified sample rate at the given bandwidth.""" raise NotImplementedError() def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation. - """ + """Decode the given codes to the quantized representation.""" raise NotImplementedError() @property def total_codebooks(self): - """Total number of codebooks. - """ + """Total number of codebooks.""" raise NotImplementedError() @property def num_codebooks(self): - """Number of active codebooks. - """ + """Number of active codebooks.""" raise NotImplementedError() def set_num_codebooks(self, n: int): - """Set the number of active codebooks. - """ + """Set the number of active codebooks.""" raise NotImplementedError() @@ -91,17 +86,14 @@ class DummyQuantizer(BaseQuantizer): @property def total_codebooks(self): - """Total number of codebooks. - """ + """Total number of codebooks.""" return 1 @property def num_codebooks(self): - """Total number of codebooks. - """ + """Total number of codebooks.""" return self.total_codebooks def set_num_codebooks(self, n: int): - """Set the number of active codebooks. - """ + """Set the number of active codebooks.""" raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py index e1896bb1788a945a1f7be6369abb255ecf72c7a0..da02a6ce3a7de15353f0fba9e826052beb67c436 100644 --- a/audiocraft/quantization/core_vq.py +++ b/audiocraft/quantization/core_vq.py @@ -75,7 +75,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): return means, bins -def orthgonal_loss_fn(t): +def orthogonal_loss_fn(t): # eq (2) from https://arxiv.org/abs/2112.00384 n = t.shape[0] normed_codes = l2norm(t) @@ -237,7 +237,7 @@ class VectorQuantization(nn.Module): orthogonal_reg_weight (float): Orthogonal regularization weights. orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. orthogonal_reg_max_codes (optional int): Maximum number of codes to consider - for orthogonal regulariation. + for orthogonal regularization. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. @@ -340,7 +340,7 @@ class VectorQuantization(nn.Module): rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] codebook = codebook[rand_ids] - orthogonal_reg_loss = orthgonal_loss_fn(codebook) + orthogonal_reg_loss = orthogonal_loss_fn(codebook) loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight quantize = self.project_out(quantize) diff --git a/audiocraft/quantization/vq.py b/audiocraft/quantization/vq.py index f67c3a0cd30d4b8993a36c587f00dc8a451d926f..aa57bea59db95ddae35e0657f723ca3a29ee943b 100644 --- a/audiocraft/quantization/vq.py +++ b/audiocraft/quantization/vq.py @@ -30,7 +30,7 @@ class ResidualVectorQuantizer(BaseQuantizer): orthogonal_reg_weight (float): Orthogonal regularization weights. orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. - for orthogonal regulariation. + for orthogonal regularization. """ def __init__( self, @@ -96,8 +96,7 @@ class ResidualVectorQuantizer(BaseQuantizer): return codes def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation. - """ + """Decode the given codes to the quantized representation.""" # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. codes = codes.transpose(0, 1) quantized = self.vq.decode(codes) diff --git a/audiocraft/solvers/__init__.py b/audiocraft/solvers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae19f3a8c51abf469697d6affa91449d668716ba --- /dev/null +++ b/audiocraft/solvers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Solvers. A Solver is a training recipe, combining the dataloaders, models, +optimizer, losses etc into a single convenient object. +""" + +# flake8: noqa +from .audiogen import AudioGenSolver +from .builders import get_solver +from .base import StandardSolver +from .compression import CompressionSolver +from .musicgen import MusicGenSolver +from .diffusion import DiffusionSolver diff --git a/audiocraft/solvers/audiogen.py b/audiocraft/solvers/audiogen.py new file mode 100644 index 0000000000000000000000000000000000000000..1568f97fe7b84b90c7ef760ef5606fe0a475545a --- /dev/null +++ b/audiocraft/solvers/audiogen.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from . import builders, musicgen + + +class AudioGenSolver(musicgen.MusicGenSolver): + """Solver for AudioGen re-implementation training task. + + Note that this implementation does not strictly follows + the method proposed in https://arxiv.org/abs/2209.15352 + but is derived from MusicGen's training pipeline. + + More information can be found in the AudioGen model card. + """ + DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND diff --git a/audiocraft/solvers/base.py b/audiocraft/solvers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0432e44a36838c5731711f9d54f81822b21f20bd --- /dev/null +++ b/audiocraft/solvers/base.py @@ -0,0 +1,631 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from pathlib import Path +import typing as tp + +import flashy +import omegaconf +import torch +from torch import nn + +from .. import optim +from ..optim import fsdp +from ..utils import checkpoint +from ..utils.autocast import TorchAutocast +from ..utils.best_state import BestStateDictManager +from ..utils.deadlock import DeadlockDetect +from ..utils.profiler import Profiler +from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng + + +class StandardSolver(ABC, flashy.BaseSolver): + """Standard solver for AudioCraft. + + The standard solver implements a base training loop with the following stages: + train, valid, evaluate and generate that are expected to be all defined for + solvers in AudioCraft. It also provides a nice default management of Dora history replay, + checkpoint management across epoch, and logging configuration. + + AudioCraft solvers must inherit from the StandardSolver and define the methods + associated to each stage as well as the show, build_model and build_dataloaders methods. + """ + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__() + self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}") + self.logger.info(f"All XP logs are stored in {self.xp.folder}") + self.cfg = cfg + self.device = cfg.device + self.model: nn.Module + self._continue_best_source_keys = ['best_state', 'fsdp_best_state'] + self._fsdp_modules: tp.List[fsdp.FSDP] = [] + self._ema_sources: nn.ModuleDict = nn.ModuleDict() + self.ema: tp.Optional[optim.ModuleDictEMA] = None + self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict() + self._log_updates = self.cfg.logging.get('log_updates', 10) + if self.cfg.logging.log_tensorboard: + self.init_tensorboard(**self.cfg.get('tensorboard')) + if self.cfg.logging.log_wandb and self: + self.init_wandb(**self.cfg.get('wandb')) + # keep a copy of the best performing state for stateful objects + # used for evaluation and generation stages + dtype_best: tp.Optional[torch.dtype] = None + if self.cfg.fsdp.use: + dtype_best = getattr(torch, self.cfg.fsdp.param_dtype) # type: ignore + assert isinstance(dtype_best, torch.dtype) + elif self.cfg.autocast: + dtype_best = getattr(torch, self.cfg.autocast_dtype) # type: ignore + assert isinstance(dtype_best, torch.dtype) + self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best) + # Hacky support for keeping a copy of the full best state in rank0. + self.fsdp_best_state: tp.Dict[str, tp.Any] = {} + self.register_stateful('best_state', 'fsdp_best_state') # register best_state object to keep it in state_dict + self._new_best_state: bool = False # should save a new checkpoint + # instantiate datasets and appropriate number of updates per epoch + self.build_dataloaders() + if self.cfg.execute_only is None: + assert 'train' in self.dataloaders, "The train dataset split must be provided." + assert 'valid' in self.dataloaders, "The valid dataset split must be provided." + self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0 + if self.cfg.optim.updates_per_epoch: + self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch + self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs + # instantiate model & exponential moving average on the model + self.build_model() + self.logger.info("Model hash: %s", model_hash(self.model)) + assert 'model' in self.stateful.sources, \ + "Please register the model to stateful with self.register_stateful('model') in build_model." + self.profiler = Profiler(self.model, **self.cfg.profiler) + self.initialize_ema() + self.register_stateful('ema') + assert self.ema is None or 'ema' in self.stateful.sources, \ + "Please register the ema to stateful with self.register_stateful('ema') in build_model." + self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock) + # basic statistics on the trained model + model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6 + # one copy of grad, one copy of momentum, one copy of denominator and model weights. + # and 4 bytes for each float! + mem_usage = model_size * 4 * 4 / 1000 + self.logger.info("Model size: %.2f M params", model_size) + self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage) + + @property + def autocast(self): + """Convenient autocast (or not) using the solver configuration.""" + return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype) + + def _get_state_source(self, name) -> flashy.state.StateDictSource: + # Internal utility to get a state source from the solver + return self.stateful.sources[name] + + @property + def best_metric_name(self) -> tp.Optional[str]: + """Metric name used to identify the best state. This metric should be stored in the metrics + used on the stage for best state identification (most likely, `valid`). If None, then + no best state is saved. + """ + return None + + def register_best_state(self, *args: str): + """Register state sources in `BestStateDictManager` to keep their best states along with their + latest states. The best state will be used at evaluation stages instead of the latest states. + + Shortcut around `BestStateDictManager.register` method. You can pass any number of + attribute, included nested attributes and those will be included into the checkpoints + and automatically restored when `BaseSolver.restore` is called. + """ + for name in args: + state_source = self._get_state_source(name) + assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!" + self.best_state.register(name, state_source) + + def register_ema(self, *args: str): + """Register state sources for exponential moving average. + + The registered sources are used to instantiate a ModuleDictEMA instance. + The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called + and swapped with the original state sources with self.swap_ema_state() method. + + Usage: + self.register_ema('model') + """ + assert self.ema is None, "Cannot register state source to already instantiated EMA." + for name in args: + self._ema_sources[name] = getattr(self, name) + + def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs): + model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs) + if isinstance(model, fsdp.FSDP): + self._fsdp_modules.append(model) + return model + + def update_best_state_from_stage(self, stage_name: str = 'valid'): + """Update latest best state based on pending metrics of a given stage. This method relies + on the `BestStateDictManager.update` method to update the best state_dict with latest weights + if the registered states happen to match to the best performing setup. + """ + if self.best_metric_name is None: + # when no best metric is defined, the last state is always the best + self._new_best_state = True + self.logger.info("Updating best state with current state.") + else: + assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found." + assert self.best_metric_name in self._pending_metrics[stage_name], \ + f"Best metric not found in {stage_name} metrics. Cannot register best state" + current_score = self._pending_metrics[stage_name][self.best_metric_name] + all_best_metric_scores = [ + past_metrics[stage_name][self.best_metric_name] + for past_metrics in self.history + ] + all_best_metric_scores.append(current_score) + best_score = min(all_best_metric_scores) + self._new_best_state = current_score == best_score + if self._new_best_state: + old_best = min(all_best_metric_scores[:-1] + [float('inf')]) + self.logger.info( + f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})") + + if self._new_best_state: + if self.cfg.fsdp.use: + # this will give an empty state dict on all ranks but the rank 0 + # which will have a copy in memory of the full model. + with fsdp.switch_to_full_state_dict(self._fsdp_modules): + for name in self.best_state.states.keys(): + state_source = self._get_state_source(name) + self.best_state.update(name, state_source) + # we save to a different dict. + self.fsdp_best_state.update(self.best_state.state_dict()) + # We cannot efficiently load fsdp_best_state when using FSDP, + # so we have do do a second pass, with the local shards. + for name in self.best_state.states.keys(): + state_source = self._get_state_source(name) + self.best_state.update(name, state_source) + + def _load_new_state_dict(self, state_dict: dict) -> dict: + old_states = {} + for name, new_state in state_dict.items(): + state_source = self._get_state_source(name) + old_states[name] = copy_state(state_source.state_dict()) + state_source.load_state_dict(new_state) + return old_states + + @contextmanager + def swap_best_state(self): + self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}") + old_states = self._load_new_state_dict(self.best_state.state_dict()) + try: + yield + finally: + self.logger.debug("Swapping back from best to original state") + for name, old_state in old_states.items(): + state_source = self._get_state_source(name) + state_source.load_state_dict(old_state) + + @contextmanager + def swap_ema_state(self): + if self.ema is None: + yield + else: + ema_state_dict = self.ema.state_dict()['state'] + self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}") + old_states = self._load_new_state_dict(ema_state_dict) + try: + yield + finally: + self.logger.debug("Swapping back from EMA state to original state") + for name, old_state in old_states.items(): + state_source = self._get_state_source(name) + state_source.load_state_dict(old_state) + + @property + def is_training(self): + return self.current_stage == 'train' + + def log_model_summary(self, model: nn.Module): + """Log model summary, architecture and size of the model.""" + self.logger.info(model) + mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20 + self.logger.info("Size: %.1f MB", mb) + + @abstractmethod + def build_model(self): + """Method to implement to initialize model.""" + ... + + def initialize_ema(self): + """Initialize exponential moving average with the registered sources. + EMA object is created if the optim.ema.model.decay value is non-null. + """ + from .builders import get_ema + self.ema = get_ema(self._ema_sources, self.cfg.optim.ema) + if self.ema is None: + self.logger.info('No EMA on the model.') + else: + assert self.cfg.optim.ema.updates > 0 + self.logger.info( + f'Initializing EMA on the model with decay = {self.ema.decay}' + f' every {self.cfg.optim.ema.updates} updates' + ) + + @abstractmethod + def build_dataloaders(self): + """Method to implement to initialize dataloaders.""" + ... + + @abstractmethod + def show(self): + """Method to log any information without running the job.""" + ... + + @property + def log_updates(self): + # convenient access to log updates + return self._log_updates + + def checkpoint_path(self, **kwargs): + kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) + return self.folder / checkpoint.checkpoint_name(**kwargs) + + def epoch_checkpoint_path(self, epoch: int, **kwargs): + kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) + return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs) + + def checkpoint_path_with_name(self, name: str, **kwargs): + kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) + return self.folder / checkpoint.checkpoint_name(name=name, **kwargs) + + def save_checkpoints(self): + """Save checkpoint, optionally keeping a copy for a given epoch.""" + is_sharded = self.cfg.fsdp.use + if not flashy.distrib.is_rank_zero() and not is_sharded: + return + self.logger.info("Model hash: %s", model_hash(self.model)) + state = self.state_dict() + epoch = self.epoch - 1 # pushing metrics will increase the epoch in Flashy, so we do -1 here + + # save minimal state_dict as new checkpoint every X epoch + if self.cfg.checkpoint.save_every: + if epoch % self.cfg.checkpoint.save_every == 0: + minimal_state = state + if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0: + minimal_state = { + name: source for name, source in state.items() + if name in self.cfg.checkpoint.keep_every_states + } + epoch_checkpoint_path = self.epoch_checkpoint_path(epoch) + checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded) + + # save checkpoint as latest checkpoint + if self.cfg.checkpoint.save_last: + last_checkpoint_path = self.checkpoint_path() + checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded) + + # flush any stale checkpoint to reduce disk footprint + checkpoint.flush_stale_checkpoints(self.checkpoint_path()) + + def load_from_pretrained(self, name: str) -> dict: + raise NotImplementedError("Solver does not provide a way to load pretrained models.") + + def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]: + """Load last checkpoint or the one specified in continue_from. + + Args: + load_best (bool): Whether to load from best state dict or not. + Best state dict is always used when not loading the current xp. + ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`. + Returns: + state (dict, optional): The loaded state dictionary. + """ + # load checkpoints from xp folder or cfg.continue_from + is_sharded = self.cfg.fsdp.use + load_from_path: tp.Optional[Path] = None + checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None + + if load_best: + self.logger.info("Trying to load state_dict from best state.") + + state: tp.Optional[dict] = None + rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False) + current_checkpoint_path = self.checkpoint_path() + _pretrained_prefix = '//pretrained/' + continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix) + if rank0_checkpoint_path.exists(): + self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}") + load_from_path = current_checkpoint_path + checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path) + checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP + elif self.cfg.continue_from and not continue_pretrained: + self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}") + # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best + load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False) + if load_from_path is None: + self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from) + raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}') + checkpoint_source = checkpoint.CheckpointSource.OTHER + + if load_from_path is not None: + state = checkpoint.load_checkpoint(load_from_path, is_sharded) + elif continue_pretrained: + self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.") + state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):]) + checkpoint_source = checkpoint.CheckpointSource.PRETRAINED + load_best = True + + # checkpoints are not from the current xp, we only retrieve the best state + if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP: + assert state is not None + self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.") + load_best = True + state = {key: state[key] for key in self._continue_best_source_keys if key in state} + # loaded checkpoints are FSDP checkpoints: we're reading the best state + # from FSDP and we drop the regular best_state + if 'fsdp_best_state' in state and state['fsdp_best_state']: + state.pop('best_state', None) + self.logger.info("... Loaded checkpoint has FSDP best state") + # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support + # then we're initializing FSDP best state with the regular best state + elif self.cfg.fsdp.use: + if 'fsdp_best_state' not in state or not state['fsdp_best_state']: + # we swap non-FSDP checkpoints best_state to FSDP-compatible best state + state['fsdp_best_state'] = state.pop('best_state') + self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state") + + if state is not None: + if load_best: + self.logger.info("Ignoring keys when loading best %r", ignore_state_keys) + for key in set(ignore_state_keys): + if key in state: + state.pop(key) + has_best_state = 'best_state' in state or 'fsdp_best_state' in state + assert has_best_state, ("Trying to load best state but neither 'best_state'", + " or 'fsdp_best_state' found in checkpoints.") + self.load_state_dict(state) + + # for FSDP, let's make extra sure nothing bad happened with out of sync + # checkpoints across workers. + epoch = float(self.epoch) + avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch'] + if avg_epoch != epoch: + raise RuntimeError( + f"Inconsistent loading of checkpoints happened, our epoch is {epoch} " + f"but average of epochs is {avg_epoch}, at least one gpu must have a " + "different epoch number.") + + # on load_best, properly reinitialize state_dict, best states and ema + # otherwise we load from the current xp and don't alter anything + if load_best: + self.logger.info("Loading state_dict from best state.") + if not self.cfg.fsdp.use and self.fsdp_best_state: + # loading from an FSDP checkpoint but with FSDP deactivated + self.logger.info("... Loading from FSDP best state dict.") + self.best_state.load_state_dict(self.fsdp_best_state) + + # if load_best, we permanently override the regular state_dict with the best state + if self.cfg.fsdp.use: + self.logger.info("FSDP is used, loading from FSDP best state.") + with fsdp.switch_to_full_state_dict(self._fsdp_modules): + # this might be really fragile but okay for now. + self.load_state_dict(self.fsdp_best_state) + else: + # we permanently swap the stateful objects to their best state + self._load_new_state_dict(self.best_state.state_dict()) + + # the EMA modules should also be instantiated with best state. + # the easiest way to do so is to reinitialize a new EMA with best state loaded. + if self.ema is not None: + self.logger.info("Re-initializing EMA from best state") + self.initialize_ema() + + if self.cfg.fsdp.use: + self.logger.info("Re-initializing best state after using FSDP best state.") + for name in self.best_state.states.keys(): + state_source = self._get_state_source(name) + self.best_state.update(name, state_source) + + return state + + def restore(self, load_best: bool = False, replay_metrics: bool = False, + ignore_state_keys: tp.List[str] = []) -> bool: + """Restore the status of a solver for a given xp. + + Args: + load_best (bool): if `True`, load the best state from the checkpoint. + replay_metrics (bool): if `True`, logs all the metrics from past epochs. + ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`. + """ + self.logger.info("Restoring weights and history.") + restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys) + + self.logger.info("Model hash: %s", model_hash(self.model)) + + if replay_metrics and len(self.history) > 0: + self.logger.info("Replaying past metrics...") + for epoch, stages in enumerate(self.history): + for stage_name, metrics in stages.items(): + # We manually log the metrics summary to the result logger + # as we don't want to add them to the pending metrics + self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch', + formatter=self.get_formatter(stage_name)) + return restored_checkpoints is not None + + def commit(self, save_checkpoints: bool = True): + """Commit metrics to dora and save checkpoints at the end of an epoch.""" + # we override commit to introduce more complex checkpoint saving behaviors + self.history.append(self._pending_metrics) # This will increase self.epoch + if save_checkpoints: + self.save_checkpoints() + self._start_epoch() + if flashy.distrib.is_rank_zero(): + self.xp.link.update_history(self.history) + + def run_epoch(self): + """Run a single epoch with all stages. + + Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards. + Children solvers can extend this method with custom behavior, e.g.: + + def run_epoch(self): + ... # custom code + super().run_epoch() + ... # custom code + """ + self.run_stage('train', self.train) + with torch.no_grad(): + with self.swap_ema_state(): + self.run_stage('valid', self.valid) + # the best state is updated with EMA states if available + self.update_best_state_from_stage('valid') + with self.swap_best_state(): + if self.should_run_stage('evaluate'): + self.run_stage('evaluate', self.evaluate) + if self.should_run_stage('generate'): + self.run_stage('generate', with_rank_rng()(self.generate)) + + def run(self): + """Training loop.""" + assert len(self.state_dict()) > 0 + self.restore(replay_metrics=True) # load checkpoint and replay history + self.log_hyperparams(dict_from_config(self.cfg)) + for epoch in range(self.epoch, self.cfg.optim.epochs + 1): + if self.should_stop_training(): + return + self.run_epoch() + # Commit will send the metrics to Dora and save checkpoints by default. + self.commit() + + def should_stop_training(self) -> bool: + """Check whether we should stop training or not.""" + return self.epoch > self.cfg.optim.epochs + + def should_run_stage(self, stage_name) -> bool: + """Check whether we want to run the specified stages.""" + stage_every = self.cfg[stage_name].get('every', None) + is_last_epoch = self.epoch == self.cfg.optim.epochs + is_epoch_every = (stage_every and self.epoch % stage_every == 0) + return is_last_epoch or is_epoch_every + + @abstractmethod + def run_step(self, idx: int, batch: tp.Any, metrics: dict): + """Perform one training or valid step on a given batch.""" + ... + + def common_train_valid(self, dataset_split: str, **kwargs: tp.Any): + """Common logic for train and valid stages.""" + self.model.train(self.is_training) + + loader = self.dataloaders[dataset_split] + # get a different order for distributed training, otherwise this will get ignored + if flashy.distrib.world_size() > 1 \ + and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler): + loader.sampler.set_epoch(self.epoch) + updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader) + if self.cfg.benchmark_no_load: + self.logger.warning("Fake loading for benchmarking: re-using first batch") + batch = next(iter(loader)) + loader = [batch] * updates_per_epoch # type: ignore + lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates) + average = flashy.averager() # epoch wise average + instant_average = flashy.averager() # average between two logging + metrics: dict = {} + + with self.profiler, self.deadlock_detect: # profiler will only run for the first 20 updates. + for idx, batch in enumerate(lp): + self.deadlock_detect.update('batch') + if idx >= updates_per_epoch: + break + metrics = {} + metrics = self.run_step(idx, batch, metrics) + self.deadlock_detect.update('step') + # run EMA step + if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0: + self.logger.debug("EMA model step") + self.ema.step() + self.deadlock_detect.update('ema') + self.profiler.step() + instant_metrics = instant_average(metrics) + if lp.update(**instant_metrics): + instant_average = flashy.averager() # reset averager between two logging + metrics = average(metrics) # epoch wise average + self.deadlock_detect.update('end_batch') + + metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch) + return metrics + + def train(self): + """Train stage.""" + return self.common_train_valid('train') + + def valid(self): + """Valid stage.""" + return self.common_train_valid('valid') + + @abstractmethod + def evaluate(self): + """Evaluate stage.""" + ... + + @abstractmethod + def generate(self): + """Generate stage.""" + ... + + def run_one_stage(self, stage_name: str): + """Run only the specified stage. + This method is useful to only generate samples from a trained experiment + or rerun the validation or evaluation stages. + """ + fn = { + 'generate': with_rank_rng()(self.generate), + 'evaluate': self.evaluate, + 'valid': self.valid, + } + if stage_name not in fn: + raise ValueError(f'Trying to run stage {stage_name} is not supported.') + assert len(self.state_dict()) > 0 + self._start_epoch() + with torch.no_grad(), self.swap_best_state(): + self.run_stage(stage_name, fn[stage_name]) + if not self.cfg.execute_inplace: + self.commit(save_checkpoints=False) + + @staticmethod + def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, + device: tp.Optional[str] = None, autocast: bool = True, + batch_size: tp.Optional[int] = None, + override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, + **kwargs): + """Mostly a convenience function around audiocraft.train.get_solver_from_sig, + populating all the proper param, deactivating EMA, FSDP, loading the best state, + basically all you need to get a solver ready to "play" with in single GPU mode + and with minimal memory overhead. + + Args: + sig (str): signature to load. + dtype (str or None): potential dtype, as a string, i.e. 'float16'. + device (str or None): potential device, as a string, i.e. 'cuda'. + override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'. + """ + from audiocraft import train + our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} + our_override_cfg['autocast'] = autocast + if dtype is not None: + our_override_cfg['dtype'] = dtype + if device is not None: + our_override_cfg['device'] = device + if batch_size is not None: + our_override_cfg['dataset'] = {'batch_size': batch_size} + if override_cfg is None: + override_cfg = {} + override_cfg = omegaconf.OmegaConf.merge( + omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore + solver = train.get_solver_from_sig( + sig, override_cfg=override_cfg, + load_best=True, disable_fsdp=True, + ignore_state_keys=['optimizer', 'ema'], **kwargs) + solver.model.eval() + return solver diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py new file mode 100644 index 0000000000000000000000000000000000000000..304d8f08d33a70e8be9388c855b2ae43bdf2683b --- /dev/null +++ b/audiocraft/solvers/builders.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +All the functions to build the relevant solvers and used objects +from the Hydra config. +""" + +from enum import Enum +import logging +import typing as tp + +import dora +import flashy +import omegaconf +import torch +from torch import nn +from torch.optim import Optimizer +# LRScheduler was renamed in some torch versions +try: + from torch.optim.lr_scheduler import LRScheduler # type: ignore +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from .base import StandardSolver +from .. import adversarial, data, losses, metrics, optim +from ..utils.utils import dict_from_config, get_loader + + +logger = logging.getLogger(__name__) + + +class DatasetType(Enum): + AUDIO = "audio" + MUSIC = "music" + SOUND = "sound" + + +def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: + """Instantiate solver from config.""" + from .audiogen import AudioGenSolver + from .compression import CompressionSolver + from .musicgen import MusicGenSolver + from .diffusion import DiffusionSolver + klass = { + 'compression': CompressionSolver, + 'musicgen': MusicGenSolver, + 'audiogen': AudioGenSolver, + 'lm': MusicGenSolver, # backward compatibility + 'diffusion': DiffusionSolver, + 'sound_lm': AudioGenSolver, # backward compatibility + }[cfg.solver] + return klass(cfg) # type: ignore + + +def get_optim_parameter_groups(model: nn.Module): + """Create parameter groups for the model using the appropriate method + if defined for each modules, to create the different groups. + + Args: + model (nn.Module): torch model + Returns: + List of parameter groups + """ + seen_params: tp.Set[nn.parameter.Parameter] = set() + other_params = [] + groups = [] + for name, module in model.named_modules(): + if hasattr(module, 'make_optim_group'): + group = module.make_optim_group() + params = set(group['params']) + assert params.isdisjoint(seen_params) + seen_params |= set(params) + groups.append(group) + for param in model.parameters(): + if param not in seen_params: + other_params.append(param) + groups.insert(0, {'params': other_params}) + parameters = groups + return parameters + + +def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer: + """Build torch optimizer from config and set of parameters. + Supported optimizers: Adam, AdamW + + Args: + params (nn.Module or iterable of torch.Tensor): Parameters to optimize. + cfg (DictConfig): Optimization-related configuration. + Returns: + torch.optim.Optimizer. + """ + if 'optimizer' not in cfg: + if getattr(cfg, 'optim', None) is not None: + raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?") + else: + raise KeyError("Optimizer not found in config.") + + parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params + optimizer: torch.optim.Optimizer + if cfg.optimizer == 'adam': + optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam) + elif cfg.optimizer == 'adamw': + optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam) + elif cfg.optimizer == 'dadam': + optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam) + else: + raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") + return optimizer + + +def get_lr_scheduler(optimizer: torch.optim.Optimizer, + cfg: omegaconf.DictConfig, + total_updates: int) -> tp.Optional[LRScheduler]: + """Build torch learning rate scheduler from config and associated optimizer. + Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler + + Args: + optimizer (torch.optim.Optimizer): Optimizer. + cfg (DictConfig): Schedule-related configuration. + total_updates (int): Total number of updates. + Returns: + torch.optim.Optimizer. + """ + if 'lr_scheduler' not in cfg: + raise KeyError("LR Scheduler not found in config") + + lr_sched: tp.Optional[LRScheduler] = None + if cfg.lr_scheduler == 'step': + lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step) + elif cfg.lr_scheduler == 'exponential': + lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential) + elif cfg.lr_scheduler == 'cosine': + kwargs = dict_from_config(cfg.cosine) + warmup_steps = kwargs.pop('warmup') + lr_sched = optim.CosineLRScheduler( + optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) + elif cfg.lr_scheduler == 'polynomial_decay': + kwargs = dict_from_config(cfg.polynomial_decay) + warmup_steps = kwargs.pop('warmup') + lr_sched = optim.PolynomialDecayLRScheduler( + optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) + elif cfg.lr_scheduler == 'inverse_sqrt': + kwargs = dict_from_config(cfg.inverse_sqrt) + warmup_steps = kwargs.pop('warmup') + lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) + elif cfg.lr_scheduler == 'linear_warmup': + kwargs = dict_from_config(cfg.linear_warmup) + warmup_steps = kwargs.pop('warmup') + lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) + elif cfg.lr_scheduler is not None: + raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") + return lr_sched + + +def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]: + """Initialize Exponential Moving Average. + + Args: + module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA. + cfg (omegaconf.DictConfig): Optim EMA configuration. + Returns: + optim.ModuleDictEMA: EMA version of the ModuleDict. + """ + kw: tp.Dict[str, tp.Any] = dict(cfg) + use = kw.pop('use', False) + decay = kw.pop('decay', None) + device = kw.pop('device', None) + if not use: + return None + if len(module_dict) == 0: + raise ValueError("Trying to build EMA but an empty module_dict source is provided!") + ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device) + return ema_module + + +def get_loss(loss_name: str, cfg: omegaconf.DictConfig): + """Instantiate loss from configuration.""" + klass = { + 'l1': torch.nn.L1Loss, + 'l2': torch.nn.MSELoss, + 'mel': losses.MelSpectrogramL1Loss, + 'mrstft': losses.MRSTFTLoss, + 'msspec': losses.MultiScaleMelSpectrogramLoss, + 'sisnr': losses.SISNR, + }[loss_name] + kwargs = dict(getattr(cfg, loss_name)) + return klass(**kwargs) + + +def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer: + """Instantiate loss balancer from configuration for the provided weights.""" + kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg) + return losses.Balancer(loss_weights, **kwargs) + + +def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module: + """Initialize adversary from config.""" + klass = { + 'msd': adversarial.MultiScaleDiscriminator, + 'mpd': adversarial.MultiPeriodDiscriminator, + 'msstftd': adversarial.MultiScaleSTFTDiscriminator, + }[name] + adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name)) + return klass(**adv_cfg) + + +def get_adversarial_losses(cfg) -> nn.ModuleDict: + """Initialize dict of adversarial losses from config.""" + device = cfg.device + adv_cfg = getattr(cfg, 'adversarial') + adversaries = adv_cfg.get('adversaries', []) + adv_loss_name = adv_cfg['adv_loss'] + feat_loss_name = adv_cfg.get('feat_loss') + normalize = adv_cfg.get('normalize', True) + feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None + if feat_loss_name: + assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found." + loss = get_loss(feat_loss_name, cfg) + feat_loss = adversarial.FeatureMatchingLoss(loss, normalize) + loss = adversarial.get_adv_criterion(adv_loss_name) + loss_real = adversarial.get_real_criterion(adv_loss_name) + loss_fake = adversarial.get_fake_criterion(adv_loss_name) + adv_losses = nn.ModuleDict() + for adv_name in adversaries: + adversary = get_adversary(adv_name, cfg).to(device) + optimizer = get_optimizer(adversary.parameters(), cfg.optim) + adv_loss = adversarial.AdversarialLoss( + adversary, + optimizer, + loss=loss, + loss_real=loss_real, + loss_fake=loss_fake, + loss_feat=feat_loss, + normalize=normalize + ) + adv_losses[adv_name] = adv_loss + return adv_losses + + +def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL: + """Instantiate ViSQOL metric from config.""" + kwargs = dict_from_config(cfg) + return metrics.ViSQOL(**kwargs) + + +def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric: + """Instantiate Frechet Audio Distance metric from config.""" + kwargs = dict_from_config(cfg.tf) + xp = dora.get_xp() + kwargs['log_folder'] = xp.folder + return metrics.FrechetAudioDistanceMetric(**kwargs) + + +def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric: + """Instantiate KL-Divergence metric from config.""" + kld_metrics = { + 'passt': metrics.PasstKLDivergenceMetric, + } + klass = kld_metrics[cfg.model] + kwargs = dict_from_config(cfg.get(cfg.model)) + return klass(**kwargs) + + +def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric: + """Instantiate Text Consistency metric from config.""" + text_consistency_metrics = { + 'clap': metrics.CLAPTextConsistencyMetric + } + klass = text_consistency_metrics[cfg.model] + kwargs = dict_from_config(cfg.get(cfg.model)) + return klass(**kwargs) + + +def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric: + """Instantiate Chroma Cosine Similarity metric from config.""" + assert cfg.model == 'chroma_base', "Only support 'chroma_base' method for chroma cosine similarity metric" + kwargs = dict_from_config(cfg.get(cfg.model)) + return metrics.ChromaCosineSimilarityMetric(**kwargs) + + +def get_audio_datasets(cfg: omegaconf.DictConfig, + dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]: + """Build AudioDataset from configuration. + + Args: + cfg (omegaconf.DictConfig): Configuration. + dataset_type: The type of dataset to create. + Returns: + dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split. + """ + dataloaders: dict = {} + + sample_rate = cfg.sample_rate + channels = cfg.channels + seed = cfg.seed + max_sample_rate = cfg.datasource.max_sample_rate + max_channels = cfg.datasource.max_channels + + assert cfg.dataset is not None, "Could not find dataset definition in config" + + dataset_cfg = dict_from_config(cfg.dataset) + splits_cfg: dict = {} + splits_cfg['train'] = dataset_cfg.pop('train') + splits_cfg['valid'] = dataset_cfg.pop('valid') + splits_cfg['evaluate'] = dataset_cfg.pop('evaluate') + splits_cfg['generate'] = dataset_cfg.pop('generate') + execute_only_stage = cfg.get('execute_only', None) + + for split, path in cfg.datasource.items(): + if not isinstance(path, str): + continue # skipping this as not a path + if execute_only_stage is not None and split != execute_only_stage: + continue + logger.info(f"Loading audio data split {split}: {str(path)}") + assert ( + cfg.sample_rate <= max_sample_rate + ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found." + assert ( + cfg.channels <= max_channels + ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found." + + split_cfg = splits_cfg[split] + split_kwargs = {k: v for k, v in split_cfg.items()} + kwargs = {**dataset_cfg, **split_kwargs} # split kwargs overrides default dataset_cfg + kwargs['sample_rate'] = sample_rate + kwargs['channels'] = channels + + if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch: + kwargs['num_samples'] = ( + flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch) + + num_samples = kwargs['num_samples'] + shuffle = kwargs['shuffle'] + + return_info = kwargs.pop('return_info') + batch_size = kwargs.pop('batch_size', None) + num_workers = kwargs.pop('num_workers') + + if dataset_type == DatasetType.MUSIC: + dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs) + elif dataset_type == DatasetType.SOUND: + dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs) + elif dataset_type == DatasetType.AUDIO: + dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs) + else: + raise ValueError(f"Dataset type is unsupported: {dataset_type}") + + loader = get_loader( + dataset, + num_samples, + batch_size=batch_size, + num_workers=num_workers, + seed=seed, + collate_fn=dataset.collater if return_info else None, + shuffle=shuffle, + ) + dataloaders[split] = loader + + return dataloaders diff --git a/audiocraft/solvers/compression.py b/audiocraft/solvers/compression.py new file mode 100644 index 0000000000000000000000000000000000000000..b757503472a3bfbf90e1636999e64913848a7474 --- /dev/null +++ b/audiocraft/solvers/compression.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import multiprocessing +from pathlib import Path +import typing as tp + +import flashy +import omegaconf +import torch +from torch import nn + +from . import base, builders +from .. import models, quantization +from ..utils import checkpoint +from ..utils.samples.manager import SampleManager +from ..utils.utils import get_pool_executor + + +logger = logging.getLogger(__name__) + + +class CompressionSolver(base.StandardSolver): + """Solver for compression task. + + The compression task combines a set of perceptual and objective losses + to train an EncodecModel (composed of an encoder-decoder and a quantizer) + to perform high fidelity audio reconstruction. + """ + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__(cfg) + self.rng: torch.Generator # set at each epoch + self.adv_losses = builders.get_adversarial_losses(self.cfg) + self.aux_losses = nn.ModuleDict() + self.info_losses = nn.ModuleDict() + assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver." + loss_weights = dict() + for loss_name, weight in self.cfg.losses.items(): + if loss_name in ['adv', 'feat']: + for adv_name, _ in self.adv_losses.items(): + loss_weights[f'{loss_name}_{adv_name}'] = weight + elif weight > 0: + self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg) + loss_weights[loss_name] = weight + else: + self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg) + self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer) + self.register_stateful('adv_losses') + + @property + def best_metric_name(self) -> tp.Optional[str]: + # best model is the last for the compression model + return None + + def build_model(self): + """Instantiate model and optimizer.""" + # Model and optimizer + self.model = models.builders.get_compression_model(self.cfg).to(self.device) + self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) + self.register_stateful('model', 'optimizer') + self.register_best_state('model') + self.register_ema('model') + + def build_dataloaders(self): + """Instantiate audio dataloaders for each stage.""" + self.dataloaders = builders.get_audio_datasets(self.cfg) + + def show(self): + """Show the compression model and employed adversarial loss.""" + self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:") + self.log_model_summary(self.model) + self.logger.info("Adversarial loss:") + self.log_model_summary(self.adv_losses) + self.logger.info("Auxiliary losses:") + self.logger.info(self.aux_losses) + self.logger.info("Info losses:") + self.logger.info(self.info_losses) + + def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): + """Perform one training or valid step on a given batch.""" + x = batch.to(self.device) + y = x.clone() + + qres = self.model(x) + assert isinstance(qres, quantization.QuantizedResult) + y_pred = qres.x + # Log bandwidth in kb/s + metrics['bandwidth'] = qres.bandwidth.mean() + + if self.is_training: + d_losses: dict = {} + if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every: + for adv_name, adversary in self.adv_losses.items(): + disc_loss = adversary.train_adv(y_pred, y) + d_losses[f'd_{adv_name}'] = disc_loss + metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values()))) + metrics.update(d_losses) + + balanced_losses: dict = {} + other_losses: dict = {} + + # penalty from quantization + if qres.penalty is not None and qres.penalty.requires_grad: + other_losses['penalty'] = qres.penalty # penalty term from the quantizer + + # adversarial losses + for adv_name, adversary in self.adv_losses.items(): + adv_loss, feat_loss = adversary(y_pred, y) + balanced_losses[f'adv_{adv_name}'] = adv_loss + balanced_losses[f'feat_{adv_name}'] = feat_loss + + # auxiliary losses + for loss_name, criterion in self.aux_losses.items(): + loss = criterion(y_pred, y) + balanced_losses[loss_name] = loss + + # weighted losses + metrics.update(balanced_losses) + metrics.update(other_losses) + metrics.update(qres.metrics) + + if self.is_training: + # backprop losses that are not handled by balancer + other_loss = torch.tensor(0., device=self.device) + if 'penalty' in other_losses: + other_loss += other_losses['penalty'] + if other_loss.requires_grad: + other_loss.backward(retain_graph=True) + ratio1 = sum(p.grad.data.norm(p=2).pow(2) + for p in self.model.parameters() if p.grad is not None) + assert isinstance(ratio1, torch.Tensor) + metrics['ratio1'] = ratio1.sqrt() + + # balancer losses backward, returns effective training loss + # with effective weights at the current batch. + metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred) + # add metrics corresponding to weight ratios + metrics.update(self.balancer.metrics) + ratio2 = sum(p.grad.data.norm(p=2).pow(2) + for p in self.model.parameters() if p.grad is not None) + assert isinstance(ratio2, torch.Tensor) + metrics['ratio2'] = ratio2.sqrt() + + # optim + flashy.distrib.sync_model(self.model) + if self.cfg.optim.max_norm: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.optim.max_norm + ) + self.optimizer.step() + self.optimizer.zero_grad() + + # informative losses only + info_losses: dict = {} + with torch.no_grad(): + for loss_name, criterion in self.info_losses.items(): + loss = criterion(y_pred, y) + info_losses[loss_name] = loss + + metrics.update(info_losses) + + # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups + adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')] + if len(adv_losses) > 0: + metrics['adv'] = torch.sum(torch.stack(adv_losses)) + feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')] + if len(feat_losses) > 0: + metrics['feat'] = torch.sum(torch.stack(feat_losses)) + + return metrics + + def run_epoch(self): + # reset random seed at the beginning of the epoch + self.rng = torch.Generator() + self.rng.manual_seed(1234 + self.epoch) + # run epoch + super().run_epoch() + + def evaluate(self): + """Evaluate stage. Runs audio reconstruction evaluation.""" + self.model.eval() + evaluate_stage_name = str(self.current_stage) + + loader = self.dataloaders['evaluate'] + updates = len(loader) + lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) + average = flashy.averager() + + pendings = [] + ctx = multiprocessing.get_context('spawn') + with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool: + for idx, batch in enumerate(lp): + x = batch.to(self.device) + with torch.no_grad(): + qres = self.model(x) + + y_pred = qres.x.cpu() + y = batch.cpu() # should already be on CPU but just in case + pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg)) + + metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates) + for pending in metrics_lp: + metrics = pending.result() + metrics = average(metrics) + + metrics = flashy.distrib.average_metrics(metrics, len(loader)) + return metrics + + def generate(self): + """Generate stage.""" + self.model.eval() + sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True) + generate_stage_name = str(self.current_stage) + + loader = self.dataloaders['generate'] + updates = len(loader) + lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) + + for batch in lp: + reference, _ = batch + reference = reference.to(self.device) + with torch.no_grad(): + qres = self.model(reference) + assert isinstance(qres, quantization.QuantizedResult) + + reference = reference.cpu() + estimate = qres.x.cpu() + sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) + + flashy.distrib.barrier() + + def load_from_pretrained(self, name: str) -> dict: + model = models.CompressionModel.get_pretrained(name) + if isinstance(model, models.DAC): + raise RuntimeError("Cannot fine tune a DAC model.") + elif isinstance(model, models.HFEncodecCompressionModel): + self.logger.warning('Trying to automatically convert a HuggingFace model ' + 'to AudioCraft, this might fail!') + state = model.model.state_dict() + new_state = {} + for k, v in state.items(): + if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k: + # We need to determine if this a convtr or a regular conv. + layer = int(k.split('.')[2]) + if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d): + + k = k.replace('.conv.', '.convtr.') + k = k.replace('encoder.layers.', 'encoder.model.') + k = k.replace('decoder.layers.', 'decoder.model.') + k = k.replace('conv.', 'conv.conv.') + k = k.replace('convtr.', 'convtr.convtr.') + k = k.replace('quantizer.layers.', 'quantizer.vq.layers.') + k = k.replace('.codebook.', '._codebook.') + new_state[k] = v + state = new_state + elif isinstance(model, models.EncodecModel): + state = model.state_dict() + else: + raise RuntimeError(f"Cannot fine tune model type {type(model)}.") + return { + 'best_state': {'model': state} + } + + @staticmethod + def model_from_checkpoint(checkpoint_path: tp.Union[Path, str], + device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel: + """Instantiate a CompressionModel from a given checkpoint path or dora sig. + This method is a convenient endpoint to load a CompressionModel to use in other solvers. + + Args: + checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. + This also supports pre-trained models by using a path of the form //pretrained/NAME. + See `model_from_pretrained` for a list of supported pretrained models. + use_ema (bool): Use EMA variant of the model instead of the actual model. + device (torch.device or str): Device on which the model is loaded. + """ + checkpoint_path = str(checkpoint_path) + if checkpoint_path.startswith('//pretrained/'): + name = checkpoint_path.split('/', 3)[-1] + return models.CompressionModel.get_pretrained(name, device) + logger = logging.getLogger(__name__) + logger.info(f"Loading compression model from checkpoint: {checkpoint_path}") + _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False) + assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}" + state = checkpoint.load_checkpoint(_checkpoint_path) + assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}" + cfg = state['xp.cfg'] + cfg.device = device + compression_model = models.builders.get_compression_model(cfg).to(device) + assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + + assert 'best_state' in state and state['best_state'] != {} + assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix." + compression_model.load_state_dict(state['best_state']['model']) + compression_model.eval() + logger.info("Compression model loaded!") + return compression_model + + @staticmethod + def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig, + checkpoint_path: tp.Union[Path, str], + device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel: + """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig. + + Args: + cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode. + checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. + use_ema (bool): Use EMA variant of the model instead of the actual model. + device (torch.device or str): Device on which the model is loaded. + """ + compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device) + compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg) + return compression_model + + +def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict: + """Audio reconstruction evaluation method that can be conveniently pickled.""" + metrics = {} + if cfg.evaluate.metrics.visqol: + visqol = builders.get_visqol(cfg.metrics.visqol) + metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate) + sisnr = builders.get_loss('sisnr', cfg) + metrics['sisnr'] = sisnr(y_pred, y) + return metrics diff --git a/audiocraft/solvers/diffusion.py b/audiocraft/solvers/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..93dea2520836f458ab1b8514dca952b51d113ec2 --- /dev/null +++ b/audiocraft/solvers/diffusion.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import flashy +import julius +import omegaconf +import torch +import torch.nn.functional as F + +from . import builders +from . import base +from .. import models +from ..modules.diffusion_schedule import NoiseSchedule +from ..metrics import RelativeVolumeMel +from ..models.builders import get_processor +from ..utils.samples.manager import SampleManager +from ..solvers.compression import CompressionSolver + + +class PerStageMetrics: + """Handle prompting the metrics per stage. + It outputs the metrics per range of diffusion states. + e.g. avg loss when t in [250, 500] + """ + def __init__(self, num_steps: int, num_stages: int = 4): + self.num_steps = num_steps + self.num_stages = num_stages + + def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]): + if type(step) is int: + stage = int((step / self.num_steps) * self.num_stages) + return {f"{name}_{stage}": loss for name, loss in losses.items()} + elif type(step) is torch.Tensor: + stage_tensor = ((step / self.num_steps) * self.num_stages).long() + out: tp.Dict[str, float] = {} + for stage_idx in range(self.num_stages): + mask = (stage_tensor == stage_idx) + N = mask.sum() + stage_out = {} + if N > 0: # pass if no elements in the stage + for name, loss in losses.items(): + stage_loss = (mask * loss).sum() / N + stage_out[f"{name}_{stage_idx}"] = stage_loss + out = {**out, **stage_out} + return out + + +class DataProcess: + """Apply filtering or resampling. + + Args: + initial_sr (int): Initial sample rate. + target_sr (int): Target sample rate. + use_resampling: Whether to use resampling or not. + use_filter (bool): + n_bands (int): Number of bands to consider. + idx_band (int): + device (torch.device or str): + cutoffs (): + boost (bool): + """ + def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False, + use_filter: bool = False, n_bands: int = 4, + idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False): + """Apply filtering or resampling + Args: + initial_sr (int): sample rate of the dataset + target_sr (int): sample rate after resampling + use_resampling (bool): whether or not performs resampling + use_filter (bool): when True filter the data to keep only one frequency band + n_bands (int): Number of bands used + cuts (none or list): The cutoff frequencies of the band filtering + if None then we use mel scale bands. + idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs + boost (bool): make the data scale match our music dataset. + """ + assert idx_band < n_bands + self.idx_band = idx_band + if use_filter: + if cutoffs is not None: + self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device) + else: + self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device) + self.use_filter = use_filter + self.use_resampling = use_resampling + self.target_sr = target_sr + self.initial_sr = initial_sr + self.boost = boost + + def process_data(self, x, metric=False): + if x is None: + return None + if self.boost: + x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4) + x * 0.22 + if self.use_filter and not metric: + x = self.filter(x)[self.idx_band] + if self.use_resampling: + x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr) + return x + + def inverse_process(self, x): + """Upsampling only.""" + if self.use_resampling: + x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr) + return x + + +class DiffusionSolver(base.StandardSolver): + """Solver for compression task. + + The diffusion task allows for MultiBand diffusion model training. + + Args: + cfg (DictConfig): Configuration. + """ + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__(cfg) + self.cfg = cfg + self.device = cfg.device + self.sample_rate: int = self.cfg.sample_rate + self.codec_model = CompressionSolver.model_from_checkpoint( + cfg.compression_model_checkpoint, device=self.device) + + self.codec_model.set_num_codebooks(cfg.n_q) + assert self.codec_model.sample_rate == self.cfg.sample_rate, ( + f"Codec model sample rate is {self.codec_model.sample_rate} but " + f"Solver sample rate is {self.cfg.sample_rate}." + ) + assert self.codec_model.sample_rate == self.sample_rate, \ + f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \ + "don't match." + + self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate) + self.register_stateful('sample_processor') + self.sample_processor.to(self.device) + + self.schedule = NoiseSchedule( + **cfg.schedule, device=self.device, sample_processor=self.sample_processor) + + self.eval_metric: tp.Optional[torch.nn.Module] = None + + self.rvm = RelativeVolumeMel() + self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr, + use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs, + use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands, + idx_band=cfg.filter.idx_band, device=self.device) + + @property + def best_metric_name(self) -> tp.Optional[str]: + if self._current_stage == "evaluate": + return 'rvm' + else: + return 'loss' + + @torch.no_grad() + def get_condition(self, wav: torch.Tensor) -> torch.Tensor: + codes, scale = self.codec_model.encode(wav) + assert scale is None, "Scaled compression models not supported." + emb = self.codec_model.decode_latent(codes) + return emb + + def build_model(self): + """Build model and optimizer as well as optional Exponential Moving Average of the model. + """ + # Model and optimizer + self.model = models.builders.get_diffusion_model(self.cfg).to(self.device) + self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) + self.register_stateful('model', 'optimizer') + self.register_best_state('model') + self.register_ema('model') + + def build_dataloaders(self): + """Build audio dataloaders for each stage.""" + self.dataloaders = builders.get_audio_datasets(self.cfg) + + def show(self): + # TODO + raise NotImplementedError() + + def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): + """Perform one training or valid step on a given batch.""" + x = batch.to(self.device) + loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss + + condition = self.get_condition(x) # [bs, 128, T/hop, n_emb] + sample = self.data_processor.process_data(x) + + input_, target, step = self.schedule.get_training_item(sample, + tensor_step=self.cfg.schedule.variable_step_batch) + out = self.model(input_, step, condition=condition).sample + + base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2)) + reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2)) + loss = base_loss / reference_loss ** self.cfg.loss.norm_power + + if self.is_training: + loss.mean().backward() + flashy.distrib.sync_model(self.model) + self.optimizer.step() + self.optimizer.zero_grad() + metrics = { + 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(), + } + metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step)) + metrics.update({ + 'std_in': input_.std(), 'std_out': out.std()}) + return metrics + + def run_epoch(self): + # reset random seed at the beginning of the epoch + self.rng = torch.Generator() + self.rng.manual_seed(1234 + self.epoch) + self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage) + # run epoch + super().run_epoch() + + def evaluate(self): + """Evaluate stage. + Runs audio reconstruction evaluation. + """ + self.model.eval() + evaluate_stage_name = f'{self.current_stage}' + loader = self.dataloaders['evaluate'] + updates = len(loader) + lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates) + + metrics = {} + n = 1 + for idx, batch in enumerate(lp): + x = batch.to(self.device) + with torch.no_grad(): + y_pred = self.regenerate(x) + + y_pred = y_pred.cpu() + y = batch.cpu() # should already be on CPU but just in case + rvm = self.rvm(y_pred, y) + lp.update(**rvm) + if len(metrics) == 0: + metrics = rvm + else: + for key in rvm.keys(): + metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1) + metrics = flashy.distrib.average_metrics(metrics) + return metrics + + @torch.no_grad() + def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None): + """Regenerate the given waveform.""" + condition = self.get_condition(wav) + initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes. + result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition, + step_list=step_list) + result = self.data_processor.inverse_process(result) + return result + + def generate(self): + """Generate stage.""" + sample_manager = SampleManager(self.xp) + self.model.eval() + generate_stage_name = f'{self.current_stage}' + + loader = self.dataloaders['generate'] + updates = len(loader) + lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) + + for batch in lp: + reference, _ = batch + reference = reference.to(self.device) + estimate = self.regenerate(reference) + reference = reference.cpu() + estimate = estimate.cpu() + sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) + flashy.distrib.barrier() diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py new file mode 100644 index 0000000000000000000000000000000000000000..2439da33e5e3cf78a526fcc1a5d630a349e735ed --- /dev/null +++ b/audiocraft/solvers/musicgen.py @@ -0,0 +1,705 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +import time +import typing as tp +import warnings + +import flashy +import math +import omegaconf +import torch +from torch.nn import functional as F + +from . import base, builders +from .compression import CompressionSolver +from .. import metrics as eval_metrics +from .. import models +from ..data.audio_dataset import AudioDataset +from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo +from ..data.audio_utils import normalize_audio +from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition +from ..utils.cache import CachedBatchWriter, CachedBatchLoader +from ..utils.samples.manager import SampleManager +from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once + + +class MusicGenSolver(base.StandardSolver): + """Solver for MusicGen training task. + + Used in: https://arxiv.org/abs/2306.05284 + """ + DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC + + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__(cfg) + # easier access to sampling parameters + self.generation_params = { + 'use_sampling': self.cfg.generate.lm.use_sampling, + 'temp': self.cfg.generate.lm.temp, + 'top_k': self.cfg.generate.lm.top_k, + 'top_p': self.cfg.generate.lm.top_p, + } + self._best_metric_name: tp.Optional[str] = 'ce' + + self._cached_batch_writer = None + self._cached_batch_loader = None + if cfg.cache.path: + if cfg.cache.write: + self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path)) + if self.cfg.cache.write_num_shards: + self.logger.warning("Multiple shard cache, best_metric_name will be set to None.") + self._best_metric_name = None + else: + self._cached_batch_loader = CachedBatchLoader( + Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers, + min_length=self.cfg.optim.updates_per_epoch or 1) + self.dataloaders['original_train'] = self.dataloaders['train'] + self.dataloaders['train'] = self._cached_batch_loader # type: ignore + + @staticmethod + def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, + device: tp.Optional[str] = None, autocast: bool = True, + batch_size: tp.Optional[int] = None, + override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, + **kwargs): + """Mostly a convenience function around magma.train.get_solver_from_sig, + populating all the proper param, deactivating EMA, FSDP, loading the best state, + basically all you need to get a solver ready to "play" with in single GPU mode + and with minimal memory overhead. + + Args: + sig (str): signature to load. + dtype (str or None): potential dtype, as a string, i.e. 'float16'. + device (str or None): potential device, as a string, i.e. 'cuda'. + override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'. + """ + from audiocraft import train + our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} + our_override_cfg['autocast'] = autocast + if dtype is not None: + our_override_cfg['dtype'] = dtype + if device is not None: + our_override_cfg['device'] = device + if batch_size is not None: + our_override_cfg['dataset'] = {'batch_size': batch_size} + if override_cfg is None: + override_cfg = {} + override_cfg = omegaconf.OmegaConf.merge( + omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore + solver = train.get_solver_from_sig( + sig, override_cfg=override_cfg, + load_best=True, disable_fsdp=True, + ignore_state_keys=['optimizer', 'ema'], **kwargs) + solver.model.eval() + return solver + + def get_formatter(self, stage_name: str) -> flashy.Formatter: + return flashy.Formatter({ + 'lr': '.2E', + 'ce': '.3f', + 'ppl': '.3f', + 'grad_norm': '.3E', + }, exclude_keys=['ce_q*', 'ppl_q*']) + + @property + def best_metric_name(self) -> tp.Optional[str]: + return self._best_metric_name + + def build_model(self) -> None: + """Instantiate models and optimizer.""" + # we can potentially not use all quantizers with which the EnCodec model was trained + # (e.g. we trained the model with quantizers dropout) + self.compression_model = CompressionSolver.wrapped_model_from_checkpoint( + self.cfg, self.cfg.compression_model_checkpoint, device=self.device) + assert self.compression_model.sample_rate == self.cfg.sample_rate, ( + f"Compression model sample rate is {self.compression_model.sample_rate} but " + f"Solver sample rate is {self.cfg.sample_rate}." + ) + # ensure we have matching configuration between LM and compression model + assert self.cfg.transformer_lm.card == self.compression_model.cardinality, ( + "Cardinalities of the LM and compression model don't match: ", + f"LM cardinality is {self.cfg.transformer_lm.card} vs ", + f"compression model cardinality is {self.compression_model.cardinality}" + ) + assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, ( + "Numbers of codebooks of the LM and compression models don't match: ", + f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ", + f"compression model numer of codebooks is {self.compression_model.num_codebooks}" + ) + self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d", + self.compression_model.num_codebooks, self.compression_model.cardinality, + self.compression_model.frame_rate) + # instantiate LM model + self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device) + if self.cfg.fsdp.use: + assert not self.cfg.autocast, "Cannot use autocast with fsdp" + self.model = self.wrap_with_fsdp(self.model) + self.register_ema('model') + # initialize optimization + self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) + self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) + self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler') + self.register_best_state('model') + self.autocast_dtype = { + 'float16': torch.float16, 'bfloat16': torch.bfloat16 + }[self.cfg.autocast_dtype] + self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None + if self.cfg.fsdp.use: + need_scaler = self.cfg.fsdp.param_dtype == 'float16' + else: + need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 + if need_scaler: + if self.cfg.fsdp.use: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + self.scaler = ShardedGradScaler() # type: ignore + else: + self.scaler = torch.cuda.amp.GradScaler() + self.register_stateful('scaler') + + def build_dataloaders(self) -> None: + """Instantiate audio dataloaders for each stage.""" + self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE) + + def show(self) -> None: + """Show the compression model and LM model.""" + self.logger.info("Compression model:") + self.log_model_summary(self.compression_model) + self.logger.info("LM model:") + self.log_model_summary(self.model) + + def load_state_dict(self, state: dict) -> None: + if 'condition_provider' in state: + model_state = state['model'] + condition_provider_state = state.pop('condition_provider') + prefix = 'condition_provider.' + for key, value in condition_provider_state.items(): + key = prefix + key + assert key not in model_state + model_state[key] = value + super().load_state_dict(state) + + def load_from_pretrained(self, name: str): + # TODO: support native HF versions of MusicGen. + lm_pkg = models.loaders.load_lm_model_ckpt(name) + state: dict = { + 'best_state': { + 'model': lm_pkg['best_state'], + }, + } + return state + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + def _prepare_tokens_and_attributes( + self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], + check_synchronization_points: bool = False + ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]: + """Prepare input batchs for language model training. + + Args: + batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T] + and corresponding metadata as SegmentWithAttributes (with B items). + check_synchronization_points (bool): Whether to check for synchronization points slowing down training. + Returns: + Condition tensors (dict[str, any]): Preprocessed condition attributes. + Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s], + with B the batch size, K the number of codebooks, T_s the token timesteps. + Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s]. + """ + if self.model.training: + warnings.warn( + "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. " + "This is inconsistent with how model were trained in the MusicGen paper. We removed the " + "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. " + "Really sorry about that.") + if self._cached_batch_loader is None or self.current_stage != "train": + audio, infos = batch + audio = audio.to(self.device) + audio_tokens = None + assert audio.size(0) == len(infos), ( + f"Mismatch between number of items in audio batch ({audio.size(0)})", + f" and in metadata ({len(infos)})" + ) + else: + audio = None + # In that case the batch will be a tuple coming from the _cached_batch_writer bit below. + infos, = batch # type: ignore + assert all([isinstance(info, AudioInfo) for info in infos]) + assert all([info.audio_tokens is not None for info in infos]) # type: ignore + audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore + audio_tokens = audio_tokens.long() + for info in infos: + if isinstance(info, MusicInfo): + # Careful here, if you want to use this condition_wav (e.b. chroma conditioning), + # then you must be using the chroma cache! otherwise the code will try + # to use this segment and fail (by that I mean you will see NaN everywhere). + info.self_wav = WavCondition( + torch.full([1, info.channels, info.total_frames], float('NaN')), + length=torch.tensor([info.n_frames]), + sample_rate=[info.sample_rate], + path=[info.meta.path], + seek_time=[info.seek_time]) + dataset = get_dataset_from_loader(self.dataloaders['original_train']) + assert isinstance(dataset, MusicDataset), type(dataset) + if dataset.paraphraser is not None and info.description is not None: + # Hackingly reapplying paraphraser when using cache. + info.description = dataset.paraphraser.sample_paraphrase( + info.meta.path, info.description) + # prepare attributes + attributes = [info.to_condition_attributes() for info in infos] + attributes = self.model.cfg_dropout(attributes) + attributes = self.model.att_dropout(attributes) + tokenized = self.model.condition_provider.tokenize(attributes) + + # Now we should be synchronization free. + if self.device == "cuda" and check_synchronization_points: + torch.cuda.set_sync_debug_mode("warn") + + if audio_tokens is None: + with torch.no_grad(): + audio_tokens, scale = self.compression_model.encode(audio) + assert scale is None, "Scaled compression model not supported with LM." + + with self.autocast: + condition_tensors = self.model.condition_provider(tokenized) + + # create a padding mask to hold valid vs invalid positions + padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device) + # replace encodec tokens from padded audio with special_token_id + if self.cfg.tokens.padding_with_special_token: + audio_tokens = audio_tokens.clone() + padding_mask = padding_mask.clone() + token_sample_rate = self.compression_model.frame_rate + B, K, T_s = audio_tokens.shape + for i in range(B): + n_samples = infos[i].n_frames + audio_sample_rate = infos[i].sample_rate + # take the last token generated from actual audio frames (non-padded audio) + valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate) + audio_tokens[i, :, valid_tokens:] = self.model.special_token_id + padding_mask[i, :, valid_tokens:] = 0 + + if self.device == "cuda" and check_synchronization_points: + torch.cuda.set_sync_debug_mode("default") + + if self._cached_batch_writer is not None and self.current_stage == 'train': + assert self._cached_batch_loader is None + assert audio_tokens is not None + for info, one_audio_tokens in zip(infos, audio_tokens): + assert isinstance(info, AudioInfo) + if isinstance(info, MusicInfo): + assert not info.joint_embed, "joint_embed and cache not supported yet." + info.self_wav = None + assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item() + info.audio_tokens = one_audio_tokens.short().cpu() + self._cached_batch_writer.save(infos) + + return condition_tensors, audio_tokens, padding_mask + + def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict: + """Perform one training or valid step on a given batch.""" + check_synchronization_points = idx == 1 and self.device == 'cuda' + + condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes( + batch, check_synchronization_points) + + self.deadlock_detect.update('tokens_and_conditions') + + if check_synchronization_points: + torch.cuda.set_sync_debug_mode('warn') + + with self.autocast: + model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore + logits = model_output.logits + mask = padding_mask & model_output.mask + ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) + loss = ce + self.deadlock_detect.update('loss') + + if check_synchronization_points: + torch.cuda.set_sync_debug_mode('default') + + if self.is_training: + metrics['lr'] = self.optimizer.param_groups[0]['lr'] + if self.scaler is not None: + loss = self.scaler.scale(loss) + self.deadlock_detect.update('scale') + if self.cfg.fsdp.use: + loss.backward() + flashy.distrib.average_tensors(self.model.buffers()) + elif self.cfg.optim.eager_sync: + with flashy.distrib.eager_sync_model(self.model): + loss.backward() + else: + # this should always be slower but can be useful + # for weird use cases like multiple backwards. + loss.backward() + flashy.distrib.sync_model(self.model) + self.deadlock_detect.update('backward') + + if self.scaler is not None: + self.scaler.unscale_(self.optimizer) + if self.cfg.optim.max_norm: + if self.cfg.fsdp.use: + metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore + else: + metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.optim.max_norm + ) + if self.scaler is None: + self.optimizer.step() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + if self.lr_scheduler: + self.lr_scheduler.step() + self.optimizer.zero_grad() + self.deadlock_detect.update('optim') + if self.scaler is not None: + scale = self.scaler.get_scale() + metrics['grad_scale'] = scale + if not loss.isfinite().all(): + raise RuntimeError("Model probably diverged.") + + metrics['ce'] = ce + metrics['ppl'] = torch.exp(ce) + for k, ce_q in enumerate(ce_per_codebook): + metrics[f'ce_q{k + 1}'] = ce_q + metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) + + return metrics + + @torch.no_grad() + def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], + gen_duration: float, prompt_duration: tp.Optional[float] = None, + remove_prompt: bool = False, + **generation_params) -> dict: + """Run generate step on a batch of optional audio tensor and corresponding attributes. + + Args: + batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): + use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch. + gen_duration (float): Target audio duration for the generation. + prompt_duration (float, optional): Duration for the audio prompt to use for continuation. + remove_prompt (bool, optional): Whether to remove the prompt from the generated audio. + generation_params: Additional generation parameters. + Returns: + gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation + and the prompt along with additional information. + """ + bench_start = time.time() + audio, meta = batch + assert audio.size(0) == len(meta), ( + f"Mismatch between number of items in audio batch ({audio.size(0)})", + f" and in metadata ({len(meta)})" + ) + # prepare attributes + attributes = [x.to_condition_attributes() for x in meta] + # TODO: Add dropout for chroma? + + # prepare audio prompt + if prompt_duration is None: + prompt_audio = None + else: + assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration" + prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate) + prompt_audio = audio[..., :prompt_audio_frames] + + # get audio tokens from compression model + if prompt_audio is None or prompt_audio.nelement() == 0: + num_samples = len(attributes) + prompt_tokens = None + else: + num_samples = None + prompt_audio = prompt_audio.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt_audio) + assert scale is None, "Compression model in MusicGen should not require rescaling." + + # generate by sampling from the LM + with self.autocast: + total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate) + gen_tokens = self.model.generate( + prompt_tokens, attributes, max_gen_len=total_gen_len, + num_samples=num_samples, **self.generation_params) + + # generate audio from tokens + assert gen_tokens.dim() == 3 + gen_audio = self.compression_model.decode(gen_tokens, None) + + bench_end = time.time() + gen_outputs = { + 'rtf': (bench_end - bench_start) / gen_duration, + 'ref_audio': audio, + 'gen_audio': gen_audio, + 'gen_tokens': gen_tokens, + 'prompt_audio': prompt_audio, + 'prompt_tokens': prompt_tokens, + } + return gen_outputs + + def generate_audio(self) -> dict: + """Audio generation stage.""" + generate_stage_name = f'{self.current_stage}' + sample_manager = SampleManager(self.xp) + self.logger.info(f"Generating samples in {sample_manager.base_folder}") + loader = self.dataloaders['generate'] + updates = len(loader) + lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) + + dataset = get_dataset_from_loader(loader) + dataset_duration = dataset.segment_duration + assert dataset_duration is not None + assert isinstance(dataset, AudioDataset) + target_duration = self.cfg.generate.lm.gen_duration + prompt_duration = self.cfg.generate.lm.prompt_duration + if target_duration is None: + target_duration = dataset_duration + if prompt_duration is None: + prompt_duration = dataset_duration / 4 + assert prompt_duration < dataset_duration, ( + f"Specified prompt duration ({prompt_duration}s) is longer", + f" than reference audio duration ({dataset_duration}s)" + ) + + def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]): + hydrated_conditions = [] + for sample in [x.to_condition_attributes() for x in meta]: + cond_dict = {} + for cond_type in sample.__annotations__.keys(): + for cond_key, cond_val in getattr(sample, cond_type).items(): + if cond_key not in self.model.condition_provider.conditioners.keys(): + continue + if is_jsonable(cond_val): + cond_dict[cond_key] = cond_val + elif isinstance(cond_val, WavCondition): + cond_dict[cond_key] = cond_val.path + elif isinstance(cond_val, JointEmbedCondition): + cond_dict[cond_key] = cond_val.text # only support text at inference for now + else: + # if we reached this point, it is not clear how to log the condition + # so we just log the type. + cond_dict[cond_key] = str(type(cond_val)) + continue + hydrated_conditions.append(cond_dict) + return hydrated_conditions + + metrics: dict = {} + average = flashy.averager() + for batch in lp: + audio, meta = batch + # metadata for sample manager + hydrated_conditions = get_hydrated_conditions(meta) + sample_generation_params = { + **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()}, + **self.generation_params + } + if self.cfg.generate.lm.unprompted_samples: + if self.cfg.generate.lm.gen_gt_samples: + # get the ground truth instead of generation + self.logger.warn( + "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true") + gen_unprompted_audio = audio + rtf = 1. + else: + gen_unprompted_outputs = self.run_generate_step( + batch, gen_duration=target_duration, prompt_duration=None, + **self.generation_params) + gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu() + rtf = gen_unprompted_outputs['rtf'] + sample_manager.add_samples( + gen_unprompted_audio, self.epoch, hydrated_conditions, + ground_truth_wavs=audio, generation_args=sample_generation_params) + + if self.cfg.generate.lm.prompted_samples: + gen_outputs = self.run_generate_step( + batch, gen_duration=target_duration, prompt_duration=prompt_duration, + **self.generation_params) + gen_audio = gen_outputs['gen_audio'].cpu() + prompt_audio = gen_outputs['prompt_audio'].cpu() + sample_manager.add_samples( + gen_audio, self.epoch, hydrated_conditions, + prompt_wavs=prompt_audio, ground_truth_wavs=audio, + generation_args=sample_generation_params) + + metrics['rtf'] = rtf + metrics = average(metrics) + + flashy.distrib.barrier() + return metrics + + def generate(self) -> dict: + """Generate stage.""" + self.model.eval() + with torch.no_grad(): + return self.generate_audio() + + def run_epoch(self): + if self.cfg.cache.write: + if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard: + return + super().run_epoch() + + def train(self): + """Train stage. + """ + if self._cached_batch_writer is not None: + self._cached_batch_writer.start_epoch(self.epoch) + if self._cached_batch_loader is None: + dataset = get_dataset_from_loader(self.dataloaders['train']) + assert isinstance(dataset, AudioDataset) + dataset.current_epoch = self.epoch + else: + self._cached_batch_loader.start_epoch(self.epoch) + return super().train() + + def evaluate_audio_generation(self) -> dict: + """Evaluate audio generation with off-the-shelf metrics.""" + evaluate_stage_name = f'{self.current_stage}_generation' + # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation + fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None + kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None + text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None + chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None + should_run_eval = False + eval_chroma_wavs: tp.Optional[torch.Tensor] = None + if self.cfg.evaluate.metrics.fad: + fad = builders.get_fad(self.cfg.metrics.fad).to(self.device) + should_run_eval = True + if self.cfg.evaluate.metrics.kld: + kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device) + should_run_eval = True + if self.cfg.evaluate.metrics.text_consistency: + text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device) + should_run_eval = True + if self.cfg.evaluate.metrics.chroma_cosine: + chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device) + # if we have predefind wavs for chroma we should purge them for computing the cosine metric + has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \ + self.model.condition_provider.conditioners['self_wav'].has_eval_wavs() + if has_predefined_eval_chromas: + warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! " + 'Resetting eval chromas to None for evaluation.') + eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs # type: ignore + self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None) # type: ignore + should_run_eval = True + + def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor: + audio_tokens, scale = self.compression_model.encode(audio.to(self.device)) + compressed_audio = self.compression_model.decode(audio_tokens, scale) + return compressed_audio[..., :audio.shape[-1]] + + metrics: dict = {} + if should_run_eval: + loader = self.dataloaders['evaluate'] + updates = len(loader) + lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) + average = flashy.averager() + dataset = get_dataset_from_loader(loader) + assert isinstance(dataset, AudioDataset) + self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples") + + for idx, batch in enumerate(lp): + audio, meta = batch + assert all([self.cfg.sample_rate == m.sample_rate for m in meta]) + + target_duration = audio.shape[-1] / self.cfg.sample_rate + if self.cfg.evaluate.fixed_generation_duration: + target_duration = self.cfg.evaluate.fixed_generation_duration + + gen_outputs = self.run_generate_step( + batch, gen_duration=target_duration, + **self.generation_params + ) + y_pred = gen_outputs['gen_audio'].detach() + y_pred = y_pred[..., :audio.shape[-1]] + + normalize_kwargs = dict(self.cfg.generate.audio) + normalize_kwargs.pop('format', None) + y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu() + y = audio.cpu() # should already be on CPU but just in case + sizes = torch.tensor([m.n_frames for m in meta]) # actual sizes without padding + sample_rates = torch.tensor([m.sample_rate for m in meta]) # sample rates for audio samples + audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta] + + if fad is not None: + if self.cfg.metrics.fad.use_gt: + y_pred = get_compressed_audio(y).cpu() + fad.update(y_pred, y, sizes, sample_rates, audio_stems) + if kldiv is not None: + if self.cfg.metrics.kld.use_gt: + y_pred = get_compressed_audio(y).cpu() + kldiv.update(y_pred, y, sizes, sample_rates) + if text_consistency is not None: + texts = [m.description for m in meta] + if self.cfg.metrics.text_consistency.use_gt: + y_pred = y + text_consistency.update(y_pred, texts, sizes, sample_rates) + if chroma_cosine is not None: + if self.cfg.metrics.chroma_cosine.use_gt: + y_pred = get_compressed_audio(y).cpu() + chroma_cosine.update(y_pred, y, sizes, sample_rates) + # restore chroma conditioner's eval chroma wavs + if eval_chroma_wavs is not None: + self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs) + + flashy.distrib.barrier() + if fad is not None: + metrics['fad'] = fad.compute() + if kldiv is not None: + kld_metrics = kldiv.compute() + metrics.update(kld_metrics) + if text_consistency is not None: + metrics['text_consistency'] = text_consistency.compute() + if chroma_cosine is not None: + metrics['chroma_cosine'] = chroma_cosine.compute() + metrics = average(metrics) + metrics = flashy.distrib.average_metrics(metrics, len(loader)) + + return metrics + + def evaluate(self) -> dict: + """Evaluate stage.""" + self.model.eval() + with torch.no_grad(): + metrics: dict = {} + if self.cfg.evaluate.metrics.base: + metrics.update(self.common_train_valid('evaluate')) + gen_metrics = self.evaluate_audio_generation() + return {**metrics, **gen_metrics} diff --git a/audiocraft/train.py b/audiocraft/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5851222c39e173f91dc9dafe962470c52cf2fba6 --- /dev/null +++ b/audiocraft/train.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Entry point for dora to launch solvers for running training loops. +See more info on how to use dora: https://github.com/facebookresearch/dora +""" + +import logging +import multiprocessing +import os +from pathlib import Path +import sys +import typing as tp + +from dora import git_save, hydra_main, XP +import flashy +import hydra +import omegaconf + +from .environment import AudioCraftEnvironment +from .utils.cluster import get_slurm_parameters + +logger = logging.getLogger(__name__) + + +def resolve_config_dset_paths(cfg): + """Enable Dora to load manifest from git clone repository.""" + # manifest files for the different splits + for key, value in cfg.datasource.items(): + if isinstance(value, str): + cfg.datasource[key] = git_save.to_absolute_path(value) + + +def get_solver(cfg): + from . import solvers + # Convert batch size to batch size for each GPU + assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0 + cfg.dataset.batch_size //= flashy.distrib.world_size() + for split in ['train', 'valid', 'evaluate', 'generate']: + if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'): + assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0 + cfg.dataset[split].batch_size //= flashy.distrib.world_size() + resolve_config_dset_paths(cfg) + solver = solvers.get_solver(cfg) + return solver + + +def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, + restore: bool = True, load_best: bool = True, + ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True): + """Given a XP, return the Solver object. + + Args: + xp (XP): Dora experiment for which to retrieve the solver. + override_cfg (dict or None): If not None, should be a dict used to + override some values in the config of `xp`. This will not impact + the XP signature or folder. The format is different + than the one used in Dora grids, nested keys should actually be nested dicts, + not flattened, e.g. `{'optim': {'batch_size': 32}}`. + restore (bool): If `True` (the default), restore state from the last checkpoint. + load_best (bool): If `True` (the default), load the best state from the checkpoint. + ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`. + disable_fsdp (bool): if True, disables FSDP entirely. This will + also automatically skip loading the EMA. For solver specific + state sources, like the optimizer, you might want to + use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`. + """ + logger.info(f"Loading solver from XP {xp.sig}. " + f"Overrides used: {xp.argv}") + cfg = xp.cfg + if override_cfg is not None: + cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg)) + if disable_fsdp and cfg.fsdp.use: + cfg.fsdp.use = False + assert load_best is True + # ignoring some keys that were FSDP sharded like model, ema, and best_state. + # fsdp_best_state will be used in that case. When using a specific solver, + # one is responsible for adding the relevant keys, e.g. 'optimizer'. + # We could make something to automatically register those inside the solver, but that + # seem overkill at this point. + ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state'] + + try: + with xp.enter(): + solver = get_solver(cfg) + if restore: + solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys) + return solver + finally: + hydra.core.global_hydra.GlobalHydra.instance().clear() + + +def get_solver_from_sig(sig: str, *args, **kwargs): + """Return Solver object from Dora signature, i.e. to play with it from a notebook. + See `get_solver_from_xp` for more information. + """ + xp = main.get_xp_from_sig(sig) + return get_solver_from_xp(xp, *args, **kwargs) + + +def init_seed_and_system(cfg): + import numpy as np + import torch + import random + from audiocraft.modules.transformer import set_efficient_attention_backend + + multiprocessing.set_start_method(cfg.mp_start_method) + logger.debug('Setting mp start method to %s', cfg.mp_start_method) + random.seed(cfg.seed) + np.random.seed(cfg.seed) + # torch also initialize cuda seed if available + torch.manual_seed(cfg.seed) + torch.set_num_threads(cfg.num_threads) + os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads) + os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads) + logger.debug('Setting num threads to %d', cfg.num_threads) + set_efficient_attention_backend(cfg.efficient_attention_backend) + logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend) + if 'SLURM_JOB_ID' in os.environ: + tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID']) + if tmpdir.exists(): + logger.info("Changing tmpdir to %s", tmpdir) + os.environ['TMPDIR'] = str(tmpdir) + + +@hydra_main(config_path='../config', config_name='config', version_base='1.1') +def main(cfg): + init_seed_and_system(cfg) + + # Setup logging both to XP specific folder, and to stderr. + log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}' + flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name) + # Initialize distributed training, no need to specify anything when using Dora. + flashy.distrib.init() + solver = get_solver(cfg) + if cfg.show: + solver.show() + return + + if cfg.execute_only: + assert cfg.execute_inplace or cfg.continue_from is not None, \ + "Please explicitly specify the checkpoint to continue from with continue_from= " + \ + "when running with execute_only or set execute_inplace to True." + solver.restore(replay_metrics=False) # load checkpoint + solver.run_one_stage(cfg.execute_only) + return + + return solver.run() + + +main.dora.dir = AudioCraftEnvironment.get_dora_dir() +main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm) + +if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK): + print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr) + main.dora.shared = None + +if __name__ == '__main__': + main() diff --git a/audiocraft/utils/__init__.py b/audiocraft/utils/__init__.py index 0952fcc3f57e34b3747962e9ebd6fc57aeea63fa..75e25a0212f98e4a18d97c86c6cda225636a3215 100644 --- a/audiocraft/utils/__init__.py +++ b/audiocraft/utils/__init__.py @@ -3,3 +3,4 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +"""Utilities.""" diff --git a/audiocraft/utils/best_state.py b/audiocraft/utils/best_state.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ad551432ad5cb0f83278b5d2100f9aa287958b --- /dev/null +++ b/audiocraft/utils/best_state.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging +import typing as tp + +import flashy +import torch + +from ..optim import ModuleDictEMA +from .utils import copy_state + + +logger = logging.getLogger(__name__) + + +class BestStateDictManager(flashy.state.StateDictSource): + """BestStateDictManager maintains a copy of best state_dict() for registered sources. + + BestStateDictManager has two main attributes: + states (dict): State dict of the registered StateDictSource. + param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources. + + When registering new sources, the BestStateDictManager will ensure two conflicting sources between + ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about + what to consider for best state. + + Args: + device (torch.device or str): Device on which we keep the copy. + dtype (torch.dtype): Data type for the state parameters. + """ + def __init__(self, device: tp.Union[torch.device, str] = 'cpu', + dtype: tp.Optional[torch.dtype] = None): + self.device = device + self.states: dict = {} + self.param_ids: dict = defaultdict(dict) + self.dtype = dtype + + def _get_parameter_ids(self, state_dict): + return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)} + + def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict): + for registered_name, registered_param_ids in self.param_ids.items(): + if registered_name != name: + overlap = set.intersection(registered_param_ids.keys(), param_ids.keys()) + assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters" + f" in {name} and already registered {registered_name}: {' '.join(overlap)}" + + def update(self, name: str, source: flashy.state.StateDictSource): + if name not in self.states: + raise ValueError(f"{name} missing from registered states.") + self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) + + def register(self, name: str, source: flashy.state.StateDictSource): + if name in self.states: + raise ValueError(f"{name} already present in states.") + # Registering parameter ids for EMA and non-EMA states allows us to check that + # there is no overlap that would create ambiguity about how to handle the best state + param_ids = self._get_parameter_ids(source.state_dict()) + if isinstance(source, ModuleDictEMA): + logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params") + self._validate_no_parameter_ids_overlap(name, param_ids) + self.param_ids[name] = param_ids + else: + logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params") + self._validate_no_parameter_ids_overlap('base', param_ids) + self.param_ids['base'].update(param_ids) + # Register state + self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) + + def state_dict(self) -> flashy.state.StateDict: + return self.states + + def load_state_dict(self, state: flashy.state.StateDict): + for name, sub_state in state.items(): + for k, v in sub_state.items(): + self.states[name][k].copy_(v) diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f82064e8f43b86af1071cab4d967cca9b5bd86 --- /dev/null +++ b/audiocraft/utils/cache.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ThreadPoolExecutor +from collections import deque +from functools import partial +from hashlib import sha1 +import logging +from pathlib import Path +import sys +import typing as tp +import zipfile + +import flashy +import torch + + +logger = logging.getLogger(__name__) + + +def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor: + """Utility function for the EmbeddingCache, returning the full embedding without any chunking. + This method can be used in case there is no need in extracting a chunk of the full embedding + read from the cache. + + Args: + full_embed (torch.Tensor): The full embedding. + x (any): Batch object from which the full embedding is derived. + idx (torch.Tensor): Index of object to consider in the batch object. + Returns: + full_embed (torch.Tensor): The full embedding + """ + return full_embed.to(device) + + +class EmbeddingCache: + """Cache around embeddings computation for faster execution. + The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API + to retrieve the pre-computed embeddings on full inputs and extract only a given chunk + using a user-provided function. When the cache is warm (all embeddings are pre-computed), + the EmbeddingCache allows for faster training as it removes the need of computing the embeddings. + Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint + and synchronization points in the forward calls. + + Args: + cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk. + device (str or torch.device): Device on which the embedding is returned. + compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute + the embedding from a given object and path. This user provided function can compute the + embedding from the provided object or using the provided path as entry point. The last parameter + specify the index corresponding to the current embedding in the object that can represent batch metadata. + extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract + the desired embedding chunk from the full embedding loaded from the cache. The last parameter + specify the index corresponding to the current embedding in the object that can represent batch metadata. + If not specified, will return the full embedding unmodified. + """ + def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device], + compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor], + extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None): + self.cache_path = Path(cache_path) + self.device = device + self._compute_embed_fn = compute_embed_fn + self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor] + if extract_embed_fn is not None: + self._extract_embed_fn = extract_embed_fn + else: + self._extract_embed_fn = partial(get_full_embed, device=device) + if self.cache_path is not None: + self.cache_path.mkdir(exist_ok=True, parents=True) + logger.info(f"Cache instantiated at: {self.cache_path}") + self.pool = ThreadPoolExecutor(8) + self.pool.__enter__() + self._current_batch_cache: dict = {} + self._memory_cache: dict = {} + + def _get_cache_path(self, path: tp.Union[Path, str]): + """Get cache path for the given file path.""" + sig = sha1(str(path).encode()).hexdigest() + return self.cache_path / sig + + @staticmethod + def _get_full_embed_from_cache(cache: Path): + """Loads full pre-computed embedding from the cache.""" + try: + embed = torch.load(cache, 'cpu') + except Exception as exc: + logger.error("Error loading %s: %r", cache, exc) + embed = None + return embed + + def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor: + """Get embedding from cache, computing and storing it to cache if not already cached. + The EmbeddingCache first tries to load the embedding from the in-memory cache + containing the pre-computed chunks populated through `populate_embed_cache`. + If not found, the full embedding is computed and stored on disk to be later accessed + to populate the in-memory cache, and the desired embedding chunk is extracted and returned. + + Args: + paths (list[Path or str]): List of paths from where the embeddings can be loaded. + x (any): Object from which the embedding is extracted. + """ + embeds = [] + for idx, path in enumerate(paths): + cache = self._get_cache_path(path) + if cache in self._current_batch_cache: + embed = self._current_batch_cache[cache] + else: + full_embed = self._compute_embed_fn(path, x, idx) + try: + with flashy.utils.write_and_rename(cache, pid=True) as f: + torch.save(full_embed.cpu(), f) + except Exception as exc: + logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc) + else: + logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape) + embed = self._extract_embed_fn(full_embed, x, idx) + embeds.append(embed) + embed = torch.stack(embeds, dim=0) + return embed + + def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: + """Populate in-memory caches for embeddings reading from the embeddings stored on disk. + The in-memory caches consist in a cache for the full embedding and another cache for the + final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings + and reduce the IO footprint and synchronization points during forward passes. + + Args: + paths (list[Path]): List of paths from where the embeddings can be loaded. + x (any): Object from which the embedding is extracted. + """ + self._current_batch_cache.clear() + if self.cache_path is not None: + futures: list = [] + for path in paths: + assert path is not None, "Path is required for computation from cache" + cache = self._get_cache_path(path) + if cache in self._memory_cache or not cache.exists(): + futures.append(None) + else: + futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache)) + for idx, (path, future) in enumerate(zip(paths, futures)): + assert path is not None + cache = self._get_cache_path(path) + full_embed = None + if future is None: + if cache in self._memory_cache: + full_embed = self._memory_cache[cache] + else: + full_embed = future.result() + if full_embed is not None: + self._memory_cache[cache] = full_embed + full_embed = full_embed.to(self.device) + if full_embed is not None: + embed = self._extract_embed_fn(full_embed, x, idx) + self._current_batch_cache[cache] = embed + + +class CachedBatchWriter: + """Write pre computed caches for mini batches. This can + make loading a lot more efficient depending on your filesystem. + + Args: + cache_folder (Path): folder in which the cached minibatches + will be stored. + + Inside cache folder, the structure is the following: + `epoch_number / update_number.zip` + And the zip file contains one entry per batch item. + + It is possible to use the cache with a batch size smaller than + created with but obviously not larger. Make sure to call the + `start_epoch(epoch)` method for indicating changes of epochs. + + See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py` + for an example of how to warmup the cache. + """ + def __init__(self, cache_folder: Path): + self.cache_folder = cache_folder + self._current_epoch: tp.Optional[int] = None + self._current_index = 0 + + def start_epoch(self, epoch: int): + """Call at the beginning of each epoch. + """ + self._current_epoch = epoch + self._current_index = 0 + self._zip_path.parent.mkdir(exist_ok=True, parents=True) + + @staticmethod + def _get_zip_path(cache_folder: Path, epoch: int, index: int): + return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip" + + @property + def _zip_path(self): + assert self._current_epoch is not None + return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index) + + def save(self, *content): + """Save one mini batch. This function is distributed-aware + and will automatically merge all the items from the different + workers. + """ + all_contents = [] + for rank in range(flashy.distrib.world_size()): + their_content = flashy.distrib.broadcast_object(content, src=rank) + all_contents.append(their_content) + + if flashy.distrib.is_rank_zero(): + idx = 0 + with flashy.utils.write_and_rename(self._zip_path) as tmp: + with zipfile.ZipFile(tmp, 'w') as zf: + for content in all_contents: + for vals in zip(*content): + with zf.open(f'{idx}', 'w') as f: # type: ignore + torch.save(vals, f) + idx += 1 + flashy.distrib.barrier() + self._current_index += 1 + + +class CachedBatchLoader: + """Loader for cached mini-batches dumped with `CachedBatchWriter`. + + Args: + cache_folder (Path): folder in which the cached minibatches are stored. + batch_size (int): batch size (per GPU) expected. + num_workers (int): number of workers to use for loading. + min_length (int): minimum expected length for each epoch. If some + mini-batches are missing, and error is raised. + + This is iterable just like a regular DataLoader. + """ + + def __init__(self, cache_folder: Path, batch_size: int, + num_workers: int = 10, min_length: int = 1): + self.cache_folder = cache_folder + self.batch_size = batch_size + self.num_workers = num_workers + self.min_length = min_length + self._current_epoch: tp.Optional[int] = None + self.sampler = None # for compatibility with the regular DataLoader + + def __len__(self): + path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent + return len([p for p in path.iterdir() if p.suffix == ".zip"]) + + def start_epoch(self, epoch: int): + """Call at the beginning of each epoch. + """ + self._current_epoch = epoch + + def _zip_path(self, index: int): + assert self._current_epoch is not None + return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index) + + def _load_one(self, index: int): + zip_path = self._zip_path(index) + if not zip_path.exists(): + if index < self.min_length: + raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist") + + return None + mode = "rb" if sys.version_info >= (3, 9) else "r" + try: + with zipfile.ZipFile(zip_path, 'r') as zf: + rank = flashy.distrib.rank() + world_size = flashy.distrib.world_size() + root = zipfile.Path(zf) + items = list(root.iterdir()) + total_batch_size = self.batch_size * world_size + if len(items) < total_batch_size: + raise RuntimeError( + f"The cache can handle a max batch size of {len(items)}, " + f"but {total_batch_size} is needed.") + start = rank * self.batch_size + items = items[start: start + self.batch_size] + assert len(items) == self.batch_size + entries = [] + entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore + transposed = zip(*entries) + out = [] + for part in transposed: + assert len(part) > 0 + if isinstance(part[0], torch.Tensor): + out.append(torch.stack(part)) + else: + out.append(part) + return out + except Exception: + logger.error("Error when reading zip path %s", zip_path) + raise + + def __iter__(self): + """This will yields tuples, exactly as provided to the + `CachedBatchWriter.save` method. + """ + pool = ThreadPoolExecutor(self.num_workers) + next_index = 0 + queue = deque() + + def _get_next(): + nonlocal next_index + r = queue.popleft().result() + if r is None: + return None + else: + queue.append(pool.submit(self._load_one, next_index)) + next_index += 1 + return r + + with pool: + # fill the buffer of fetching jobs. + for _ in range(2 * self.num_workers): + queue.append(pool.submit(self._load_one, next_index)) + next_index += 1 + while True: + batch = _get_next() + if batch is None: + return + yield batch diff --git a/audiocraft/utils/checkpoint.py b/audiocraft/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f871837e09c5cc7832b85b0d80b84f59e87ca0 --- /dev/null +++ b/audiocraft/utils/checkpoint.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +import logging +from pathlib import Path +import re +import typing as tp + +import flashy +import torch + +from ..environment import AudioCraftEnvironment + + +logger = logging.getLogger(__name__) + + +class CheckpointSource(Enum): + CURRENT_XP = "current_xp" + PRETRAINED = "pretrained" + OTHER = "other" + + +def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str: + """Checkpoint name formatted for all use in AudioCraft codebase and has the following format: + `checkpoint_.th(.)`. By convention, name is expected to be empty for last checkpoint, + 'best' for the best checkpoint or the epoch number. + + Args: + name (str, optional): Name suffix for the checkpoint file stem. + rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. + use_fsdp (bool): Whether the calling solver relies on FSDP. + Returns: + str: The checkpoint name. + """ + suffix = '' + if rank is None: + rank = flashy.distrib.rank() + if rank > 0 and use_fsdp: + suffix = '.' + str(rank) + name_part = '' + if name is not None: + name_part = f'_{name}' + return f'checkpoint{name_part}.th{suffix}' + + +def is_sharded_checkpoint(path: Path) -> bool: + """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.""" + return re.search(r'\.th\.\d+$', path.name) is not None + + +def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None, + use_fsdp: bool = False) -> tp.Optional[Path]: + """Resolve a given checkpoint path for a provided dora sig or path. + + Args: + sig_or_path (Path or str): Checkpoint path or dora signature. + name (str, optional): Name suffix for the checkpoint file stem. + rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. + use_fsdp (bool): Whether the calling solver relies on FSDP. + Returns: + Path, optional: Resolved checkpoint path, if it exists. + """ + from audiocraft import train + xps_root = train.main.dora.dir / 'xps' + sig_or_path = str(sig_or_path) + if sig_or_path.startswith('//sig/'): + sig = sig_or_path[len('//sig/'):] + path = xps_root / sig + else: + path = Path(sig_or_path) + path = AudioCraftEnvironment.resolve_reference_path(path) + + if path.is_dir(): + path = path / checkpoint_name(name, use_fsdp=use_fsdp) + + if path.exists(): + return path + else: + return None + + +def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any: + """Load state from checkpoints at the specified checkpoint path.""" + if is_sharded: + rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False) + if rank0_checkpoint_path.exists(): + check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path) + state = torch.load(checkpoint_path, 'cpu') + logger.info("Checkpoint loaded from %s", checkpoint_path) + return state + + +def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: + """Save state to disk to the specified checkpoint_path.""" + _safe_save_checkpoint(state, checkpoint_path, is_sharded) + logger.info("Checkpoint saved to %s", checkpoint_path) + + +def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None: + """Flush checkpoints to only keep last N checkpoints.""" + if keep_last is None or keep_last <= 0: + return + checkpoint_dir = checkpoint_path.parent + suffix = '' + if flashy.distrib.rank() > 0: + suffix = f'.{flashy.distrib.rank()}' + checkpoint_files_with_epoch = [] + for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'): + epoch_part = path.name.split('.', 1)[0].split('_', 1)[1] + if epoch_part.isdigit(): + checkpoint_files_with_epoch.append((path, int(epoch_part))) + checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))] + total_to_flush = max(0, len(checkpoint_files) - keep_last) + files_to_flush = checkpoint_files[:total_to_flush] + for path in files_to_flush: + logger.debug("Removing checkpoint: %s", str(path)) + path.unlink(missing_ok=True) + + +def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None: + """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.""" + # Finish the work of a previous run that got interrupted while dumping. + old_path = Path(str(checkpoint_path) + '.old') + if old_path.exists(): + raise RuntimeError( + f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.") + token = Path(str(rank0_checkpoint_path) + '.tmp.done') + tmp_path = Path(str(checkpoint_path) + '.tmp') + if token.exists(): + if tmp_path.exists(): + tmp_path.rename(checkpoint_path) + flashy.distrib.barrier() + if flashy.distrib.is_rank_zero() and token.exists(): + token.unlink() + + +def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: + """Save checkpoints in a safe manner even with when sharded checkpoints across nodes.""" + def _barrier_if_sharded(): + if is_sharded: + flashy.distrib.barrier() + + if flashy.distrib.is_rank_zero(): + token = Path(str(checkpoint_path) + '.tmp.done') + if token.exists(): + token.unlink() + _barrier_if_sharded() + with flashy.utils.write_and_rename(checkpoint_path) as f: + torch.save(state, f) + _barrier_if_sharded() + if flashy.distrib.is_rank_zero(): + token.touch() + _barrier_if_sharded() + _barrier_if_sharded() + if flashy.distrib.rank() == 0: + token.unlink() diff --git a/audiocraft/utils/cluster.py b/audiocraft/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3380d031739d473fb859c76b9c25350f47fa77e8 --- /dev/null +++ b/audiocraft/utils/cluster.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions for SLURM configuration and cluster settings. +""" + +from enum import Enum +import os +import socket +import typing as tp + +import omegaconf + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + LOCAL_DARWIN = "darwin" + DEFAULT = "default" # used for any other cluster. + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + fqdn = socket.getfqdn() + if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn): + return ClusterType.AWS + + if fqdn.endswith(".fair"): + return ClusterType.FAIR + + if fqdn.endswith(".facebook.com"): + return ClusterType.RSC + + if uname.sysname == "Darwin": + return ClusterType.LOCAL_DARWIN + + return ClusterType.DEFAULT + + +def get_cluster_type( + cluster_type: tp.Optional[ClusterType] = None, +) -> tp.Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_slurm_parameters( + cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None +) -> omegaconf.DictConfig: + """Update SLURM parameters in configuration based on cluster type. + If the cluster type is not specify, it infers it automatically. + """ + from ..environment import AudioCraftEnvironment + cluster_type = get_cluster_type(cluster_type) + # apply cluster-specific adjustments + if cluster_type == ClusterType.AWS: + cfg["mem_per_gpu"] = None + cfg["constraint"] = None + cfg["setup"] = [] + elif cluster_type == ClusterType.RSC: + cfg["mem_per_gpu"] = None + cfg["setup"] = [] + cfg["constraint"] = None + cfg["partition"] = "learn" + slurm_exclude = AudioCraftEnvironment.get_slurm_exclude() + if slurm_exclude is not None: + cfg["exclude"] = slurm_exclude + return cfg diff --git a/audiocraft/utils/deadlock.py b/audiocraft/utils/deadlock.py new file mode 100644 index 0000000000000000000000000000000000000000..8abd1bbeea5909e664cf816c020bd7c37effdb66 --- /dev/null +++ b/audiocraft/utils/deadlock.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from queue import Queue, Empty +import signal +import sys +import threading +import traceback + +logger = logging.getLogger(__name__) + + +class DeadlockDetect: + def __init__(self, use: bool = False, timeout: float = 120.): + self.use = use + self.timeout = timeout + self._queue: Queue = Queue() + + def update(self, stage: str): + if self.use: + self._queue.put(stage) + + def __enter__(self): + if self.use: + self._thread = threading.Thread(target=self._detector_thread) + self._thread.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.use: + self._queue.put(None) + self._thread.join() + + def _detector_thread(self): + logger.debug("Deadlock detector started") + last_stage = "init" + while True: + try: + stage = self._queue.get(timeout=self.timeout) + except Empty: + break + if stage is None: + logger.debug("Exiting deadlock detector thread") + return + else: + last_stage = stage + logger.error("Deadlock detector timed out, last stage was %s", last_stage) + for th in threading.enumerate(): + print(th, file=sys.stderr) + traceback.print_stack(sys._current_frames()[th.ident]) + print(file=sys.stderr) + sys.stdout.flush() + sys.stderr.flush() + os.kill(os.getpid(), signal.SIGKILL) diff --git a/audiocraft/utils/export.py b/audiocraft/utils/export.py index b513b52267f7bf5aae09282c15b0a2e20c8a8fee..28b214017d9ac23934b67e8254a96131cefa6501 100644 --- a/audiocraft/utils/export.py +++ b/audiocraft/utils/export.py @@ -11,46 +11,69 @@ Utility to export a training checkpoint to a lightweight release checkpoint. from pathlib import Path import typing as tp -from omegaconf import OmegaConf, DictConfig +from omegaconf import OmegaConf import torch +from audiocraft import __version__ -def _clean_lm_cfg(cfg: DictConfig): - OmegaConf.set_struct(cfg, False) - # This used to be set automatically in the LM solver, need a more robust solution - # for the future. - cfg['transformer_lm']['card'] = 2048 - cfg['transformer_lm']['n_q'] = 4 - # Experimental params no longer supported. - bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', - 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] - for name in bad_params: - del cfg['transformer_lm'][name] - OmegaConf.set_struct(cfg, True) - return cfg - - -def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + """Export only the best state from the given EnCodec checkpoint. This + should be used if you trained your own EnCodec model. + """ pkg = torch.load(checkpoint_path, 'cpu') new_pkg = { - 'best_state': pkg['ema']['state']['model'], + 'best_state': pkg['best_state']['model'], 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + 'version': __version__, + 'exported': True, } - out_file = Path(out_folder) / f'{sig}.th' + Path(out_file).parent.mkdir(exist_ok=True, parents=True) torch.save(new_pkg, out_file) return out_file -def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" +def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): + """Export a compression model (potentially EnCodec) from a pretrained model. + This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. + Do not include the //pretrained/ prefix. For instance if you trained a model + with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. + + In that case, this will not actually include a copy of the model, simply the reference + to the model used. + """ + if Path(pretrained_encodec).exists(): + pkg = torch.load(pretrained_encodec) + assert 'best_state' in pkg + assert 'xp.cfg' in pkg + assert 'version' in pkg + assert 'exported' in pkg + else: + pkg = { + 'pretrained': pretrained_encodec, + 'exported': True, + 'version': __version__, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(pkg, out_file) + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + """Export only the best state from the given MusicGen or AudioGen checkpoint. + """ pkg = torch.load(checkpoint_path, 'cpu') + if pkg['fsdp_best_state']: + best_state = pkg['fsdp_best_state']['model'] + else: + assert pkg['best_state'] + best_state = pkg['best_state']['model'] new_pkg = { - 'best_state': pkg['fsdp_best_state']['model'], - 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) + 'best_state': best_state, + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + 'version': __version__, + 'exported': True, } - out_file = Path(out_folder) / f'{sig}.th' + + Path(out_file).parent.mkdir(exist_ok=True, parents=True) torch.save(new_pkg, out_file) return out_file diff --git a/audiocraft/utils/export_legacy.py b/audiocraft/utils/export_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..367c3f3c9f95ae59a95edbb60b470e03cc842fbb --- /dev/null +++ b/audiocraft/utils/export_legacy.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Legacy functions used at the time of the first release, kept for referencd. +""" + +from pathlib import Path +import typing as tp + +from omegaconf import OmegaConf, DictConfig +import torch + +from audiocraft import __version__ + + +def _clean_lm_cfg(cfg: DictConfig): + OmegaConf.set_struct(cfg, False) + # This used to be set automatically in the LM solver, need a more robust solution + # for the future. + cfg['transformer_lm']['card'] = 2048 + n_q = 4 + stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None) + if stereo_cfg is not None and stereo_cfg.use: + if 'downsample' in stereo_cfg: + del stereo_cfg['downsample'] + n_q = 8 + cfg['transformer_lm']['n_q'] = n_q + # Experimental params no longer supported. + bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', + 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] + for name in bad_params: + del cfg['transformer_lm'][name] + OmegaConf.set_struct(cfg, True) + return cfg + + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['ema']['state']['model'], + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + # The following params were NOT exported for the first release of MusicGen. + 'version': __version__, + 'exported': True, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(new_pkg, out_file) + return out_file + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + pkg = torch.load(checkpoint_path, 'cpu') + if pkg['fsdp_best_state']: + best_state = pkg['fsdp_best_state']['model'] + else: + best_state = pkg['best_state']['model'] + new_pkg = { + 'best_state': best_state, + 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])), + # The following params were NOT exported for the first release of MusicGen. + 'version': __version__, + 'exported': True, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(new_pkg, out_file) + return out_file diff --git a/audiocraft/utils/profiler.py b/audiocraft/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b45b6d15910b50305c7b212c089ffad3c25b324d --- /dev/null +++ b/audiocraft/utils/profiler.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import typing as tp + +import dora +import torch + + +logger = logging.getLogger(__name__) + + +class Profiler: + """Context manager wrapper for xformers profiler. + """ + def __init__(self, module: torch.nn.Module, enabled: bool = False): + self.profiler: tp.Optional[tp.Any] = None + if enabled: + from xformers.profiler import profile + output_dir = dora.get_xp().folder / 'profiler_data' + logger.info("Profiling activated, results with be saved to %s", output_dir) + self.profiler = profile(output_dir=output_dir, module=module) + + def step(self): + if self.profiler is not None: + self.profiler.step() # type: ignore + + def __enter__(self): + if self.profiler is not None: + return self.profiler.__enter__() # type: ignore + + def __exit__(self, exc_type, exc_value, exc_tb): + if self.profiler is not None: + return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore diff --git a/audiocraft/utils/samples/__init__.py b/audiocraft/utils/samples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/audiocraft/utils/samples/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/audiocraft/utils/samples/manager.py b/audiocraft/utils/samples/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0fb21b2d2867c03f7cce6f27d9524fdb89b51d --- /dev/null +++ b/audiocraft/utils/samples/manager.py @@ -0,0 +1,386 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +API that can manage the storage and retrieval of generated samples produced by experiments. + +It offers the following benefits: +* Samples are stored in a consistent way across epoch +* Metadata about the samples can be stored and retrieved +* Can retrieve audio +* Identifiers are reliable and deterministic for prompted and conditioned samples +* Can request the samples for multiple XPs, grouped by sample identifier +* For no-input samples (not prompt and no conditions), samples across XPs are matched + by sorting their identifiers +""" + +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass +from functools import lru_cache +import hashlib +import json +import logging +from pathlib import Path +import re +import typing as tp +import unicodedata +import uuid + +import dora +import torch + +from ...data.audio import audio_read, audio_write + + +logger = logging.getLogger(__name__) + + +@dataclass +class ReferenceSample: + id: str + path: str + duration: float + + +@dataclass +class Sample: + id: str + path: str + epoch: int + duration: float + conditioning: tp.Optional[tp.Dict[str, tp.Any]] + prompt: tp.Optional[ReferenceSample] + reference: tp.Optional[ReferenceSample] + generation_args: tp.Optional[tp.Dict[str, tp.Any]] + + def __hash__(self): + return hash(self.id) + + def audio(self) -> tp.Tuple[torch.Tensor, int]: + return audio_read(self.path) + + def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: + return audio_read(self.prompt.path) if self.prompt is not None else None + + def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: + return audio_read(self.reference.path) if self.reference is not None else None + + +class SampleManager: + """Audio samples IO handling within a given dora xp. + + The sample manager handles the dumping and loading logic for generated and + references samples across epochs for a given xp, providing a simple API to + store, retrieve and compare audio samples. + + Args: + xp (dora.XP): Dora experiment object. The XP contains information on the XP folder + where all outputs are stored and the configuration of the experiment, + which is useful to retrieve audio-related parameters. + map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples + instead of generating a dedicated hash id. This is useful to allow easier comparison + with ground truth sample from the files directly without having to read the JSON metadata + to do the mapping (at the cost of potentially dumping duplicate prompts/references + depending on the task). + """ + def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False): + self.xp = xp + self.base_folder: Path = xp.folder / xp.cfg.generate.path + self.reference_folder = self.base_folder / 'reference' + self.map_reference_to_sample_id = map_reference_to_sample_id + self.samples: tp.List[Sample] = [] + self._load_samples() + + @property + def latest_epoch(self): + """Latest epoch across all samples.""" + return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0 + + def _load_samples(self): + """Scan the sample folder and load existing samples.""" + jsons = self.base_folder.glob('**/*.json') + with ThreadPoolExecutor(6) as pool: + self.samples = list(pool.map(self._load_sample, jsons)) + + @staticmethod + @lru_cache(2**26) + def _load_sample(json_file: Path) -> Sample: + with open(json_file, 'r') as f: + data: tp.Dict[str, tp.Any] = json.load(f) + # fetch prompt data + prompt_data = data.get('prompt') + prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'], + duration=prompt_data['duration']) if prompt_data else None + # fetch reference data + reference_data = data.get('reference') + reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'], + duration=reference_data['duration']) if reference_data else None + # build sample object + return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'], + prompt=prompt, conditioning=data.get('conditioning'), reference=reference, + generation_args=data.get('generation_args')) + + def _init_hash(self): + return hashlib.sha1() + + def _get_tensor_id(self, tensor: torch.Tensor) -> str: + hash_id = self._init_hash() + hash_id.update(tensor.numpy().data) + return hash_id.hexdigest() + + def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor], + conditions: tp.Optional[tp.Dict[str, str]]) -> str: + """Computes an id for a sample given its input data. + This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input. + Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned. + + Args: + index (int): Batch index, Helpful to differentiate samples from the same batch. + prompt_wav (torch.Tensor): Prompt used during generation. + conditions (dict[str, str]): Conditioning used during generation. + """ + # For totally unconditioned generations we will just use a random UUID. + # The function get_samples_for_xps will do a simple ordered match with a custom key. + if prompt_wav is None and not conditions: + return f"noinput_{uuid.uuid4().hex}" + + # Human readable portion + hr_label = "" + # Create a deterministic id using hashing + hash_id = self._init_hash() + hash_id.update(f"{index}".encode()) + if prompt_wav is not None: + hash_id.update(prompt_wav.numpy().data) + hr_label += "_prompted" + else: + hr_label += "_unprompted" + if conditions: + encoded_json = json.dumps(conditions, sort_keys=True).encode() + hash_id.update(encoded_json) + cond_str = "-".join([f"{key}={slugify(value)}" + for key, value in sorted(conditions.items())]) + cond_str = cond_str[:100] # some raw text might be too long to be a valid filename + cond_str = cond_str if len(cond_str) > 0 else "unconditioned" + hr_label += f"_{cond_str}" + else: + hr_label += "_unconditioned" + + return hash_id.hexdigest() + hr_label + + def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path: + """Stores the audio with the given stem path using the XP's configuration. + + Args: + wav (torch.Tensor): Audio to store. + stem_path (Path): Path in sample output directory with file stem to use. + overwrite (bool): When False (default), skips storing an existing audio file. + Returns: + Path: The path at which the audio is stored. + """ + existing_paths = [ + path for path in stem_path.parent.glob(stem_path.stem + '.*') + if path.suffix != '.json' + ] + exists = len(existing_paths) > 0 + if exists and overwrite: + logger.warning(f"Overwriting existing audio file with stem path {stem_path}") + elif exists: + return existing_paths[0] + + audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio) + return audio_path + + def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0, + conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None, + ground_truth_wav: tp.Optional[torch.Tensor] = None, + generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample: + """Adds a single sample. + The sample is stored in the XP's sample output directory, under a corresponding epoch folder. + Each sample is assigned an id which is computed using the input data. In addition to the + sample itself, a json file containing associated metadata is stored next to it. + + Args: + sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape]. + epoch (int): current training epoch. + index (int): helpful to differentiate samples from the same batch. + conditions (dict[str, str], optional): conditioning used during generation. + prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape]. + ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from. + Tensor of shape [channels, shape]. + generation_args (dict[str, any], optional): dictionary of other arguments used during generation. + Returns: + Sample: The saved sample. + """ + sample_id = self._get_sample_id(index, prompt_wav, conditions) + reuse_id = self.map_reference_to_sample_id + prompt, ground_truth = None, None + if prompt_wav is not None: + prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True)) + prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate + prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id) + prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration) + if ground_truth_wav is not None: + ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True)) + ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate + ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id) + ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration) + sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True) + duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate + sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args) + self.samples.append(sample) + with open(sample_path.with_suffix('.json'), 'w') as f: + json.dump(asdict(sample), f, indent=2) + return sample + + def add_samples(self, samples_wavs: torch.Tensor, epoch: int, + conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None, + prompt_wavs: tp.Optional[torch.Tensor] = None, + ground_truth_wavs: tp.Optional[torch.Tensor] = None, + generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]: + """Adds a batch of samples. + The samples are stored in the XP's sample output directory, under a corresponding + epoch folder. Each sample is assigned an id which is computed using the input data and their batch index. + In addition to the sample itself, a json file containing associated metadata is stored next to it. + + Args: + sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape]. + epoch (int): Current training epoch. + conditioning (list of dict[str, str], optional): List of conditions used during generation, + one per sample in the batch. + prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape + [batch_size, channels, shape]. + ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from. + Tensor of shape [batch_size, channels, shape]. + generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation. + Returns: + samples (list of Sample): The saved audio samples with prompts, ground truth and metadata. + """ + samples = [] + for idx, wav in enumerate(samples_wavs): + prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None + gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None + conditions = conditioning[idx] if conditioning is not None else None + samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args)) + return samples + + def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False, + exclude_unprompted: bool = False, exclude_conditioned: bool = False, + exclude_unconditioned: bool = False) -> tp.Set[Sample]: + """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain. + Please note that existing samples are loaded during the manager's initialization, and added samples through this + manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager + is the only way detect them. + + Args: + epoch (int): If provided, only return samples corresponding to this epoch. + max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch. + exclude_prompted (bool): If True, does not include samples that used a prompt. + exclude_unprompted (bool): If True, does not include samples that did not use a prompt. + exclude_conditioned (bool): If True, excludes samples that used conditioning. + exclude_unconditioned (bool): If True, excludes samples that did not use conditioning. + Returns: + Samples (set of Sample): The retrieved samples matching the provided filters. + """ + if max_epoch >= 0: + samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch) + else: + samples_epoch = self.latest_epoch if epoch < 0 else epoch + samples = { + sample + for sample in self.samples + if ( + (sample.epoch == samples_epoch) and + (not exclude_prompted or sample.prompt is None) and + (not exclude_unprompted or sample.prompt is not None) and + (not exclude_conditioned or not sample.conditioning) and + (not exclude_unconditioned or sample.conditioning) + ) + } + return samples + + +def slugify(value: tp.Any, allow_unicode: bool = False): + """Process string for safer file naming. + + Taken from https://github.com/django/django/blob/master/django/utils/text.py + + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = ( + unicodedata.normalize("NFKD", value) + .encode("ascii", "ignore") + .decode("ascii") + ) + value = re.sub(r"[^\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + +def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]: + # Create a dictionary of stable id -> sample per XP + stable_samples_per_xp = [{ + sample.id: sample for sample in samples + if sample.prompt is not None or sample.conditioning + } for samples in samples_per_xp] + # Set of all stable ids + stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()} + # Dictionary of stable id -> list of samples. If an XP does not have it, assign None + stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids} + # Filter out ids that contain None values (we only want matched samples after all) + # cast is necessary to avoid mypy linter errors. + return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples} + + +def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]: + # For unstable ids, we use a sorted list since we'll match them in order + unstable_samples_per_xp = [[ + sample for sample in sorted(samples, key=lambda x: x.id) + if sample.prompt is None and not sample.conditioning + ] for samples in samples_per_xp] + # Trim samples per xp so all samples can have a match + min_len = min([len(samples) for samples in unstable_samples_per_xp]) + unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp] + # Dictionary of index -> list of matched samples + return { + f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len) + } + + +def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]: + """Gets a dictionary of matched samples across the given XPs. + Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id + will always match the number of XPs provided and will correspond to each XP in the same order given. + In other words, only samples that can be match across all provided XPs will be returned + in order to satisfy this rule. + + There are two types of ids that can be returned: stable and unstable. + * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs + (prompts/conditioning). This is why we can match them across XPs. + * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples + that used non-deterministic, random ids. This is the case for samples that did not use prompts or + conditioning for their generation. This function will sort these samples by their id and match them + by their index. + + Args: + xps: a list of XPs to match samples from. + start_epoch (int): If provided, only return samples corresponding to this epoch or newer. + end_epoch (int): If provided, only return samples corresponding to this epoch or older. + exclude_prompted (bool): If True, does not include samples that used a prompt. + exclude_unprompted (bool): If True, does not include samples that did not use a prompt. + exclude_conditioned (bool): If True, excludes samples that used conditioning. + exclude_unconditioned (bool): If True, excludes samples that did not use conditioning. + """ + managers = [SampleManager(xp) for xp in xps] + samples_per_xp = [manager.get_samples(**kwargs) for manager in managers] + stable_samples = _match_stable_samples(samples_per_xp) + unstable_samples = _match_unstable_samples(samples_per_xp) + return dict(stable_samples, **unstable_samples) diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py index 86e1448d065fa182ca69aae00d2f2a7eea55d8a4..2c5799f8bc4ee07dd8d60d6afe67fbc5a6039215 100644 --- a/audiocraft/utils/utils.py +++ b/audiocraft/utils/utils.py @@ -5,9 +5,12 @@ # LICENSE file in the root directory of this source tree. from concurrent.futures import ProcessPoolExecutor -from functools import wraps +from contextlib import contextmanager +from functools import wraps, lru_cache import hashlib +import json import logging +from pathlib import Path import typing as tp import flashy @@ -20,6 +23,16 @@ from torch.nn.utils.rnn import pad_sequence logger = logging.getLogger(__name__) +def model_hash(model: torch.nn.Module) -> str: + """Return a model hash. This should allow us to track regressions in model init + from the logs of past experiments. + """ + hasher = hashlib.sha1() + for p in model.parameters(): + hasher.update(p.data.cpu().numpy().tobytes()) + return hasher.hexdigest() + + def dict_from_config(cfg: omegaconf.DictConfig) -> dict: """Convenience function to map an omegaconf configuration to a dictionary. @@ -172,7 +185,7 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." final_length = lengths.max().item() if not max_len else max_len final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor - return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] + return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None] def hash_trick(word: str, vocab_size: int) -> int: @@ -232,3 +245,54 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens padded_tensors = padded_tensors.transpose(0, 1) padded_tensors = padded_tensors.transpose(1, dim + 1) return padded_tensors, lens + + +# TODO: Move to flashy? +def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu', + dtype: tp.Optional[torch.dtype] = None) -> tp.Any: + if isinstance(state, torch.Tensor): + if dtype is None or not state.is_floating_point(): + dtype = state.dtype + return state.detach().to(device=device, dtype=dtype, copy=True) + elif isinstance(state, dict): + return {k: copy_state(v, device, dtype) for k, v in state.items()} + elif isinstance(state, list): + return [copy_state(v, device, dtype) for v in state] + + +# TODO: Move to flashy? +@contextmanager +def swap_state(model, state, **kwargs): + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, **kwargs) + try: + yield + finally: + model.load_state_dict(old_state) + + +@lru_cache(None) +def warn_once(logger, msg): + """Warn about a given message only once.""" + logger.warning(msg) + + +def is_jsonable(x: tp.Any): + """Check if an object can be serialized into a json:""" + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False + + +def load_clap_state_dict(clap_model, path: tp.Union[str, Path]): + """Wrapper around state dict loading of CLAP model + addressing compatibility issues between CLAP and AudioCraft + HuggingFace transformer version. + See: https://github.com/LAION-AI/CLAP/issues/118 + """ + from clap_module.factory import load_state_dict # type: ignore + pkg = load_state_dict(path) + pkg.pop('text_branch.embeddings.position_ids', None) + clap_model.model.load_state_dict(pkg) diff --git a/config/conditioner/chroma2music.yaml b/config/conditioner/chroma2music.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91d37e758ef183678cff3f7a880b6bab2e36b03c --- /dev/null +++ b/config/conditioner/chroma2music.yaml @@ -0,0 +1,46 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.2 + inference_coef: 3.0 + +attribute_dropout: + args: + active_on_eval: false + text: {} + wav: + self_wav: 0.5 + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [self_wav, description] + cross: [] + input_interpolate: [] + +conditioners: + self_wav: + model: chroma_stem + chroma_stem: + sample_rate: ${sample_rate} + n_chroma: 12 + radix2_exp: 14 + argmax: true + match_len_on_eval: false + eval_wavs: null + n_eval_wavs: 100 + cache_path: null + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.2 + normalize_text: false + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/clapemb2music.yaml b/config/conditioner/clapemb2music.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d44ac774492c3d80a0c29af330f6040a0a20264f --- /dev/null +++ b/config/conditioner/clapemb2music.yaml @@ -0,0 +1,44 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 + inference_coef: 3.0 + +attribute_dropout: + text: {} + wav: {} + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + input_interpolate: [] + +conditioners: + description: + model: clap + clap: + checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt + model_arch: 'HTSAT-base' + enable_fusion: false + sample_rate: 48000 + max_audio_length: 10 + audio_stride: 1 + dim: 512 + attribute: description + normalize: true + quantize: true # use RVQ quantization + n_q: 12 + bins: 1024 + kmeans_iters: 50 + text_p: 0. # probability of using text embed at train time + cache_path: null + +dataset: + joint_embed_attributes: [description] + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/none.yaml b/config/conditioner/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6055dc910cad46d80609aae57bb46b81f2663d70 --- /dev/null +++ b/config/conditioner/none.yaml @@ -0,0 +1,19 @@ +# @package __global__ + +# No conditioning + +classifier_free_guidance: + training_dropout: 0 + inference_coef: 1 + +attribute_dropout: + text: {} + wav: {} + +fuser: + sum: [] + prepend: [] + cross: [] + input_interpolate: [] + +conditioners: null diff --git a/config/conditioner/text2music.yaml b/config/conditioner/text2music.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d0fe6cfa3fb33bcdb4f9fd16bd5ab4034c68b7b --- /dev/null +++ b/config/conditioner/text2music.yaml @@ -0,0 +1,30 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 + inference_coef: 3.0 + +attribute_dropout: {} + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + input_interpolate: [] + +conditioners: + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.3 + normalize_text: false + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/text2sound.yaml b/config/conditioner/text2sound.yaml new file mode 100644 index 0000000000000000000000000000000000000000..555d4b7c3cecf0ec06c8cb25440b2f426c098ad2 --- /dev/null +++ b/config/conditioner/text2sound.yaml @@ -0,0 +1,24 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.1 + inference_coef: 3.0 + +attribute_dropout: {} + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + input_interpolate: [] + +conditioners: + description: + model: t5 + t5: + name: t5-large + finetune: false + word_dropout: 0. + normalize_text: false diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b0b7866eafac173fe7b056ad5920be1df57a947 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,75 @@ +# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft +# Please don't update this file directly. Instead use distinct configuration files +# to override the below configuration. +defaults: + - _self_ + - dset: default + - solver: default + +device: cuda +dtype: float32 +autocast: false +autocast_dtype: bfloat16 +seed: 2036 +show: false # just show the model and its size and exit +continue_from: # continue from a given sig or path +execute_only: # can be set to generate/evaluate/valid to run that stage +execute_inplace: false # don't enforce continue_from to be set + # to enable inplace execution of the stage. This assume + # that you know what you are doing and execute stage + # preserving the original xp sig. +benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them + +efficient_attention_backend: torch # can be torch or xformers. +num_threads: 1 # called with torch.set_num_thread. +mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server). + + +label: # use this if you want twice the same exp, with a name. + +# logging parameters +logging: + level: INFO + log_updates: 10 + log_tensorboard: false + log_wandb: false +tensorboard: + with_media_logging: false + name: # optional name for the experiment + sub_dir: # optional sub directory to store tensorboard data +wandb: + with_media_logging: true + project: # project name + name: # optional name for the experiment + group: # optional group + +# SLURM launcher configuration. +slurm: + gpus: 4 # convenience parameter, number of GPUs to use. + mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`. + time: 3600 + constraint: + partition: + comment: + setup: [] + exclude: '' + +# dora parameters +dora: + # Output folder for all artifacts of an experiment. + dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs + # The following entries will be ignored by dora when computing the unique XP signature. + # Note that slurm.* and dora.* are automatically ignored. + exclude: [ + 'device', 'wandb.*', 'tensorboard.*', 'logging.*', + 'dataset.num_workers', 'eval.num_workers', 'special.*', + 'metrics.visqol.bin', 'metrics.fad.bin', + 'execute_only', 'execute_best', 'generate.every', + 'optim.eager_sync', 'profiler.*', 'deadlock.*', + 'efficient_attention_backend', 'num_threads', 'mp_start_method', + ] + use_rendezvous: false + # for grids, always run from a clean repo, allowing reliable runs and storing + # the exact commit. Your repo must be absolutely pristine clean. + # Local `dora run` are not impacted for easier debugging. + git_save: true diff --git a/config/dset/audio/audiocaps_16khz.yaml b/config/dset/audio/audiocaps_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14f5d6a4fcbf4426b7987d4427ca2d98d17d6c5b --- /dev/null +++ b/config/dset/audio/audiocaps_16khz.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +# AudioCaps dataset +datasource: + max_sample_rate: 16000 + max_channels: 1 + + train: null # only evaluation set + valid: null # only evaluation set + evaluate: egs/audiocaps/audiocaps_16khz + generate: egs/audiocaps/audiocaps_16khz # identical to evaluate diff --git a/config/dset/audio/default.yaml b/config/dset/audio/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80be23e999c6366cc89ebcf55af6b958c0e45158 --- /dev/null +++ b/config/dset/audio/default.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +datasource: + max_sample_rate: ??? + max_channels: ??? + + train: ??? + valid: ??? + evaluate: ??? + generate: null diff --git a/config/dset/audio/example.yaml b/config/dset/audio/example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d559d6d79a1cc05a82bb09f267c446258ef9ca55 --- /dev/null +++ b/config/dset/audio/example.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +datasource: + max_sample_rate: 44100 + max_channels: 2 + + train: egs/example + valid: egs/example + evaluate: egs/example + generate: egs/example diff --git a/config/dset/audio/musiccaps_32khz.yaml b/config/dset/audio/musiccaps_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d4eea0f7a521a47b9f673fecab075c5223d2b07 --- /dev/null +++ b/config/dset/audio/musiccaps_32khz.yaml @@ -0,0 +1,12 @@ +# @package __global__ + +# total samples obtained from MusicCaps = 5469 +# (out of 5521 due to AudioSet corrupted samples) +datasource: + max_sample_rate: 32000 + max_channels: 2 + + train: null # only evaluation set + valid: null # only evaluation set + evaluate: egs/musiccaps/musiccaps_32khz + generate: egs/musiccaps/musiccaps_32khz # identical to evaluate diff --git a/config/dset/default.yaml b/config/dset/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5d730130e090b38a42984a8a87e1eea01cbf031 --- /dev/null +++ b/config/dset/default.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft +# Please don't update this file directly. Instead use distinct configuration files +# to override the below configuration. +datasource: + train: ??? + valid: ??? + evaluate: ??? + generate: ??? diff --git a/config/dset/internal/music_10k_32khz.yaml b/config/dset/internal/music_10k_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..036628abfeaa89279790547bbb5b3ee9dd69cea3 --- /dev/null +++ b/config/dset/internal/music_10k_32khz.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +# high quality music dataset with no artist overlap between splits +datasource: + max_sample_rate: 32000 + max_channels: 1 + + train: egs/music/music_10k_32khz/train + valid: egs/music/music_10k_32khz/valid + evaluate: egs/music/music_10k_32khz/test + generate: egs/music/music_10k_32khz/test # identical to evaluate diff --git a/config/dset/internal/music_400k_32khz.yaml b/config/dset/internal/music_400k_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7786880ab9c0464a0423d906c18d62bdf7194463 --- /dev/null +++ b/config/dset/internal/music_400k_32khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +datasource: + max_sample_rate: 32000 + max_channels: 1 + + train: egs/music/music_400k_32khz/train + valid: egs/music/music_400k_32khz/valid + evaluate: egs/music/music_400k_32khz/test + generate: egs/music/music_400k_32khz/test # identical to evaluate diff --git a/config/dset/internal/sounds_16khz.yaml b/config/dset/internal/sounds_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f3401a1b44ce300e22f3f64ef9c54d5c013c153 --- /dev/null +++ b/config/dset/internal/sounds_16khz.yaml @@ -0,0 +1,12 @@ +# @package __global__ + +# environmental sounds dataset compiling all datasets +# with applied filters on tags +datasource: + max_sample_rate: 16000 + max_channels: 1 + + train: egs/sound/sounds_16khz/train + valid: egs/sound/sounds_16khz/valid + evaluate: egs/sound/sounds_16khz/test + generate: egs/sound/sounds_16khz/test # identical to evaluate diff --git a/config/model/encodec/default.yaml b/config/model/encodec/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec62c6c8ef9a686890bdca8b8f27a2f1c232205d --- /dev/null +++ b/config/model/encodec/default.yaml @@ -0,0 +1,54 @@ +# @package __global__ + +compression_model: encodec + +encodec: + autoencoder: seanet + quantizer: rvq + sample_rate: ${sample_rate} + channels: ${channels} + causal: false + renormalize: false + +seanet: + dimension: 128 + channels: ${channels} + causal: ${encodec.causal} + n_filters: 32 + n_residual_layers: 1 + ratios: [8, 5, 4, 2] + activation: ELU + activation_params: {"alpha": 1.} + norm: weight_norm + norm_params: {} + kernel_size: 7 + residual_kernel_size: 3 + last_kernel_size: 7 + dilation_base: 2 + pad_mode: constant + true_skip: true + compress: 2 + lstm: 2 + disable_norm_outer_blocks: 0 + # Specific encoder or decoder params. + # You can also override any param for the encoder or decoder only + # by using Hydra `+param=` syntax, i.e.` + # `+seanet.decoder.n_filters=64`. + decoder: + trim_right_ratio: 1.0 + final_activation: null + final_activation_params: null + encoder: {} + +rvq: + n_q: 8 + q_dropout: false + bins: 1024 + decay: 0.99 + kmeans_init: true + kmeans_iters: 50 + threshold_ema_dead_code: 2 + orthogonal_reg_weight: 0.0 + orthogonal_reg_active_codes_only: false + +no_quant: {} diff --git a/config/model/encodec/encodec_base_causal.yaml b/config/model/encodec/encodec_base_causal.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ca555bcdc69433f172915400bb71c3b63e68681 --- /dev/null +++ b/config/model/encodec/encodec_base_causal.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +defaults: + - encodec/default + +encodec: + causal: true + +rvq: + n_q: 32 + q_dropout: true diff --git a/config/model/encodec/encodec_large_nq4_s320.yaml b/config/model/encodec/encodec_large_nq4_s320.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f2d77590afd8a81185358c705a6e42853e257c3 --- /dev/null +++ b/config/model/encodec/encodec_large_nq4_s320.yaml @@ -0,0 +1,13 @@ +# @package __global__ + +defaults: + - encodec/default + +seanet: + # default ratios are [8, 5, 4, 2] + n_filters: 64 + +rvq: + bins: 2048 + n_q: 4 + q_dropout: false diff --git a/config/model/encodec/encodec_large_nq4_s640.yaml b/config/model/encodec/encodec_large_nq4_s640.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fcb7e87f4f700554164b0a58e9927b2f96a2c5a --- /dev/null +++ b/config/model/encodec/encodec_large_nq4_s640.yaml @@ -0,0 +1,13 @@ +# @package __global__ + +defaults: + - encodec/default + +seanet: + ratios: [8, 5, 4, 4] + n_filters: 64 + +rvq: + bins: 2048 + n_q: 4 + q_dropout: false diff --git a/config/model/lm/audiogen_lm.yaml b/config/model/lm/audiogen_lm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d17e7a93983e04492611d19183bf731865c67dd6 --- /dev/null +++ b/config/model/lm/audiogen_lm.yaml @@ -0,0 +1,36 @@ +# @package __global__ + +defaults: + - lm/default + - override /conditioner: text2sound + - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly + +lm_model: transformer_lm + +codebooks_pattern: + modeling: delay + delay: + delays: [0, 1, 2, 3] + flatten_first: 0 + empty_initial: 0 + unroll: + flattening: [0, 1, 2, 3] + delays: [0, 0, 0, 0] + music_lm: + group_by: 2 + coarse_first: + delays: [0, 0, 0] + +transformer_lm: + n_q: 4 + card: 2048 + memory_efficient: true + bias_proj: false + bias_ff: false + bias_attn: false + norm_first: true + layer_scale: null + weight_init: gaussian + depthwise_init: current + zero_bias_init: true + attention_as_float32: false diff --git a/config/model/lm/default.yaml b/config/model/lm/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d256ad14ef69d25d62c19b73599937c8546e79b --- /dev/null +++ b/config/model/lm/default.yaml @@ -0,0 +1,47 @@ +# @package __global__ +defaults: + - _self_ + - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly + +lm_model: transformer_lm + +codebooks_pattern: + modeling: parallel + +transformer_lm: + dim: 512 + num_heads: 8 + num_layers: 8 + hidden_scale: 4 + n_q: 8 # number of streams to model + card: 1024 + dropout: 0. + emb_lr: null + activation: gelu + norm_first: false # use pre-norm instead of post-norm + bias_ff: true # use bias for the feedforward + bias_attn: true # use bias for the attention + bias_proj: true # use bias for the output projections + past_context: null + causal: true + custom: false # use custom MHA implementation + memory_efficient: false # use flash attention + attention_as_float32: false # use float32 for the attention part, + # recommended at the moment when memory_efficient is True. + layer_scale: null + positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope). + xpos: false # apply xpos decay (rope only). + checkpointing: none # layer checkpointing method, can be none, torch, xformers_default. + # torch is the slowest but uses the least memory, + # xformers_default is somewhere in between. + weight_init: null # weight initialization (null, gaussian or uniform) + depthwise_init: null # perform depthwise initialization (null, current, global) + zero_bias_init: false # initialize bias to zero if bias in linears and + # if a weight_init method is used. + norm: layer_norm # normalization method to use in transformer. + cross_attention: false + qk_layer_norm: false + qk_layer_norm_cross: false + attention_dropout: null + kv_repeat: 1 + two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not... diff --git a/config/model/lm/model_scale/base.yaml b/config/model/lm/model_scale/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3da88d2305e4c380435de1a3eecfe311ecfc82f9 --- /dev/null +++ b/config/model/lm/model_scale/base.yaml @@ -0,0 +1,3 @@ +# @package __global__ + +# overrides nothing because default is already transformer base (~ 60M params) diff --git a/config/model/lm/model_scale/large.yaml b/config/model/lm/model_scale/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d355bfb93618003ac8994bc093eb7bc96ac60114 --- /dev/null +++ b/config/model/lm/model_scale/large.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +# gpt2 inspired, even bigger (~3.3B params) +transformer_lm: + dim: 2048 + num_heads: 32 + num_layers: 48 diff --git a/config/model/lm/model_scale/medium.yaml b/config/model/lm/model_scale/medium.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c825d1ff6c3b8cc9ae4959a898e14b40409d95e8 --- /dev/null +++ b/config/model/lm/model_scale/medium.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +# gpt2 like (~1.5B params) +transformer_lm: + dim: 1536 + num_heads: 24 + num_layers: 48 diff --git a/config/model/lm/model_scale/small.yaml b/config/model/lm/model_scale/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..88d89cb5ac1b183fb3a9092834cea83aa16c70a8 --- /dev/null +++ b/config/model/lm/model_scale/small.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +# 300M Param. + +transformer_lm: + dim: 1024 + num_heads: 16 + num_layers: 24 diff --git a/config/model/lm/model_scale/xsmall.yaml b/config/model/lm/model_scale/xsmall.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e98d4370d4fe7497f12aeb58f092a88797d1afa1 --- /dev/null +++ b/config/model/lm/model_scale/xsmall.yaml @@ -0,0 +1,8 @@ +# @package _global_ +# just used for debugging or when we just want to populate the cache +# and do not care about training. + +transformer_lm: + dim: 64 + num_heads: 2 + num_layers: 2 diff --git a/config/model/lm/musicgen_lm.yaml b/config/model/lm/musicgen_lm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be1fbc14d3bdfa4ce9d01841753bb0837687cb97 --- /dev/null +++ b/config/model/lm/musicgen_lm.yaml @@ -0,0 +1,36 @@ +# @package __global__ + +defaults: + - lm/default + - override /conditioner: text2music + - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly + +lm_model: transformer_lm + +codebooks_pattern: + modeling: delay + delay: + delays: [0, 1, 2, 3] + flatten_first: 0 + empty_initial: 0 + unroll: + flattening: [0, 1, 2, 3] + delays: [0, 0, 0, 0] + music_lm: + group_by: 2 + coarse_first: + delays: [0, 0, 0] + +transformer_lm: + n_q: 4 + card: 2048 + memory_efficient: true + bias_proj: false + bias_ff: false + bias_attn: false + norm_first: true + layer_scale: null + weight_init: gaussian + depthwise_init: current + zero_bias_init: true + attention_as_float32: false diff --git a/config/model/none.yaml b/config/model/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d4169f468d462c794ee6ed25017c3d78ae45d06 --- /dev/null +++ b/config/model/none.yaml @@ -0,0 +1,4 @@ +# @package __global__ + +# This file exist so that model is recognized as a config group +# by Hydra, and Dora. A bit weird we might need a better fix someday. diff --git a/config/model/score/basic.yaml b/config/model/score/basic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75fbc3783942602beaddaa38d0aca977aeee2dda --- /dev/null +++ b/config/model/score/basic.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +diffusion_unet: + hidden: 48 + depth: 4 + res_blocks: 1 + norm_groups: 4 + kernel: 8 + stride: 4 + growth: 4 + max_channels: 10_000 + dropout: 0. + emb_all_layers: true + bilstm: false + codec_dim: null + transformer: false + cross_attention: false \ No newline at end of file diff --git a/config/solver/audiogen/audiogen_base_16khz.yaml b/config/solver/audiogen/audiogen_base_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd6aee785c74db19ce9d6f488e68e6eeb471c026 --- /dev/null +++ b/config/solver/audiogen/audiogen_base_16khz.yaml @@ -0,0 +1,70 @@ +# @package __global__ + +# This is the training loop solver +# for the base AudioGen model (text-to-sound) +# on monophonic audio sampled at 16 kHz +# using a similar EnCodec+LM setup to MusicGen +defaults: + - audiogen/default + - /model: lm/audiogen_lm + - override /dset: audio/default + - _self_ + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 16khz +# with a total stride of 320 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //reference/bd44a852/checkpoint.th + +channels: 1 +sample_rate: 16000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128) + num_workers: 10 + segment_duration: 10 + min_segment_ratio: 1.0 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + external_metadata_source: null + # sample mixing augmentation at train time + train: + batch_size: 256 # matching AudioGen paper setup + aug_p: 0.5 # perform audio mixing 50% of the time + mix_p: 0.5 # proportion of batch items mixed together + # important: note that this will reduce the + # actual batch size used at train time + # which will be equal to mix_p * batch_size + mix_snr_low: -5 + mix_snr_high: 5 + mix_min_overlap: 0.5 + +generate: + lm: + use_sampling: true + top_k: 250 + top_p: 0.0 + +optim: + epochs: 100 + optimizer: adamw + lr: 5e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: inverse_sqrt + inverse_sqrt: + warmup: 3000 + warmup_init_lr: 0.0 diff --git a/config/solver/audiogen/debug.yaml b/config/solver/audiogen/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a1dd24626e611418c4de5c1d5ac6fca50dc70876 --- /dev/null +++ b/config/solver/audiogen/debug.yaml @@ -0,0 +1,61 @@ +# @package __global__ + +# This is a minimal debugging configuration +# for MusicGen training solver +defaults: + - audiogen/default + - /model: lm/audiogen_lm + - override /model/lm/model_scale: xsmall + - override /dset: audio/example + - _self_ + +autocast: false +compression_model_checkpoint: null +transformer_lm: + n_q: 4 + card: 400 + +conditioners: + description: + model: t5 + t5: + name: t5-small + +codebooks_pattern: + modeling: parallel + +channels: 1 +sample_rate: 16000 + +deadlock: + use: false # deadlock detection + +dataset: + batch_size: 4 + segment_duration: 5 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + audio: + strategy: peak + lm: + use_sampling: false + top_k: 0 + top_p: 0.0 + +checkpoint: + save_every: 0 + keep_last: 0 + +optim: + epochs: 2 + updates_per_epoch: 10 + optimizer: adamw + lr: 1e-4 + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: null diff --git a/config/solver/audiogen/default.yaml b/config/solver/audiogen/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afee63c65e0dd7350e3e89d2133bbca221d17631 --- /dev/null +++ b/config/solver/audiogen/default.yaml @@ -0,0 +1,40 @@ +# @package __global__ + +defaults: + - /solver/musicgen/default + - _self_ + - /solver/audiogen/evaluation: none + - override /dset: audio/default + +# See config/solver/musicgen/default.yaml for a list of possible values. +# We only keep the most important here. + +autocast: true +autocast_dtype: float16 + +solver: audiogen +sample_rate: ??? +channels: ??? +compression_model_checkpoint: ??? + +tokens: + padding_with_special_token: false + +dataset: + batch_size: 128 + segment_duration: 10 + min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence. + +optim: + epochs: 100 + updates_per_epoch: 2000 + lr: 1e-4 + optimizer: adamw + max_norm: 1.0 + adam: + betas: [0.9, 0.95] + weight_decay: 0.1 + eps: 1e-8 + +schedule: + lr_scheduler: null diff --git a/config/solver/audiogen/evaluation/none.yaml b/config/solver/audiogen/evaluation/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e739995ed6488700527529862a7a24f1afdcc7a --- /dev/null +++ b/config/solver/audiogen/evaluation/none.yaml @@ -0,0 +1,5 @@ +# @package __global__ + +dataset: + evaluate: + num_samples: 10000 diff --git a/config/solver/audiogen/evaluation/objective_eval.yaml b/config/solver/audiogen/evaluation/objective_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32fcc10033f3c3ff317216fe2876c65c6834e59b --- /dev/null +++ b/config/solver/audiogen/evaluation/objective_eval.yaml @@ -0,0 +1,29 @@ +# @package __global__ + +# Setup for execute only on audiocaps for audio generation +# evaluation with objective metrics +# execute_only=evaluate + +dataset: + max_audio_duration: null + # ensure the proper values are broadcasted here for evaluate + evaluate: + min_audio_duration: 1. # some metrics requires a minimum audio length + max_audio_duration: null # all samples from audiocaps should be ~10s + num_samples: null + segment_duration: null + generate: + min_audio_duration: 1. + max_audio_duration: null + num_samples: 500 + +evaluate: + metrics: + fad: true + kld: true + text_consistency: true + +metrics: + kld: + passt: + pretrained_length: 10 # similarly to reported results in AudioGen paper diff --git a/config/solver/compression/debug.yaml b/config/solver/compression/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..54dac175278d4ff509b0e44905d6b6195441f2c6 --- /dev/null +++ b/config/solver/compression/debug.yaml @@ -0,0 +1,55 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_base_causal + - override /dset: audio/example + - _self_ + +channels: 1 +sample_rate: 16000 + +# debug config uses just L1 +losses: + adv: 0. + feat: 0. + l1: 1. + mel: 0. + msspec: 0. +# no balancer +balancer: + balance_grads: false + ema_decay: 1. + total_norm: 1. + per_batch_item: false +# no adversaries +adversarial: + adversaries: [] + adv_loss: hinge + feat_loss: l1 + +# faster model for local dev +seanet: + dimension: 16 + n_filters: 4 + +# very small dataset +dataset: + batch_size: 8 + num_workers: 10 + num_samples: 100 + segment_duration: 1 + evaluate: + batch_size: 32 + generate: + batch_size: 1 + num_samples: 5 + segment_duration: 10 + +# limited training +evaluate: + every: 5 +generate: + every: 5 +optim: + epochs: 50 diff --git a/config/solver/compression/default.yaml b/config/solver/compression/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41c812ba9ff8afe7ee10302ad5b9f05b745877d9 --- /dev/null +++ b/config/solver/compression/default.yaml @@ -0,0 +1,160 @@ +# @package __global__ + +defaults: + - ../default + - override /dset: audio/default + - _self_ + +solver: compression +sample_rate: ??? +channels: ??? + +# loss balancing +losses: + adv: 4. + feat: 4. + l1: 0.1 + mel: 0. + msspec: 2. + sisnr: 0. +balancer: + balance_grads: true + ema_decay: 0.999 + per_batch_item: true + total_norm: 1. + +adversarial: + every: 1 + adversaries: [msstftd] + adv_loss: hinge + feat_loss: l1 + +# losses hyperparameters +l1: {} +l2: {} +mrstft: + factor_sc: .5 + factor_mag: .5 + normalized: false +mel: + sample_rate: ${sample_rate} + n_fft: 1024 + hop_length: 256 + win_length: 1024 + n_mels: 64 + f_min: 64 + f_max: null + normalized: false + floor_level: 1e-5 +sisnr: + sample_rate: ${sample_rate} + segment: 5. +msspec: + sample_rate: ${sample_rate} + range_start: 6 + range_end: 11 + n_mels: 64 + f_min: 64 + f_max: null + normalized: true + alphas: false + floor_level: 1e-5 + +# metrics +metrics: + visqol: + mode: audio + bin: null # path to visqol install + model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 + +# adversaries hyperparameters +msstftd: + in_channels: 1 + out_channels: 1 + filters: 32 + norm: weight_norm + n_ffts: [1024, 2048, 512, 256, 128] + hop_lengths: [256, 512, 128, 64, 32] + win_lengths: [1024, 2048, 512, 256, 128] + activation: LeakyReLU + activation_params: {negative_slope: 0.3} +msd: + in_channels: 1 + out_channels: 1 + scale_norms: [spectral_norm, weight_norm, weight_norm] + kernel_sizes: [5, 3] + filters: 16 + max_filters: 1024 + downsample_scales: [4, 4, 4, 4] + inner_kernel_sizes: null + groups: [4, 4, 4, 4] + strides: null + paddings: null + activation: LeakyReLU + activation_params: {negative_slope: 0.3} +mpd: + in_channels: 1 + out_channels: 1 + periods: [2, 3, 5, 7, 11] + n_layers: 5 + kernel_size: 5 + stride: 3 + filters: 8 + filter_scales: 4 + max_filters: 1024 + activation: LeakyReLU + activation_params: {negative_slope: 0.3} + norm: weight_norm + +# data hyperparameters +dataset: + batch_size: 64 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 500000 + valid: + num_samples: 10000 + evaluate: + batch_size: 32 + num_samples: 10000 + generate: + batch_size: 32 + num_samples: 50 + segment_duration: 10 + +# solver hyperparameters +evaluate: + every: 25 + num_workers: 5 + metrics: + visqol: false + sisnr: true +generate: + every: 25 + num_workers: 5 + audio: + sample_rate: ${sample_rate} + +# checkpointing schedule +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + +# optimization hyperparameters +optim: + epochs: 200 + updates_per_epoch: 2000 + lr: 3e-4 + max_norm: 0. + optimizer: adam + adam: + betas: [0.5, 0.9] + weight_decay: 0. + ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used diff --git a/config/solver/compression/encodec_audiogen_16khz.yaml b/config/solver/compression/encodec_audiogen_16khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..654deaa01ba9cace3f7144cc91921791c081b32a --- /dev/null +++ b/config/solver/compression/encodec_audiogen_16khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_large_nq4_s320 + - override /dset: audio/default + - _self_ + +channels: 1 +sample_rate: 16000 diff --git a/config/solver/compression/encodec_base_24khz.yaml b/config/solver/compression/encodec_base_24khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..018ad1cd61af84b616ad3088f055e8eaa36729eb --- /dev/null +++ b/config/solver/compression/encodec_base_24khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_base_causal + - override /dset: audio/default + - _self_ + +channels: 1 +sample_rate: 24000 diff --git a/config/solver/compression/encodec_musicgen_32khz.yaml b/config/solver/compression/encodec_musicgen_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eca4b90fb221372dace164fe59bb15822207a980 --- /dev/null +++ b/config/solver/compression/encodec_musicgen_32khz.yaml @@ -0,0 +1,10 @@ +# @package __global__ + +defaults: + - compression/default + - /model: encodec/encodec_large_nq4_s640 + - override /dset: audio/default + - _self_ + +channels: 1 +sample_rate: 32000 diff --git a/config/solver/default.yaml b/config/solver/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7452ea1e415516dceaaae86d692cbb8c811bd57 --- /dev/null +++ b/config/solver/default.yaml @@ -0,0 +1,108 @@ +# @package __global__ + +# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft +# Please don't update this file directly. Instead use distinct configuration files +# to override the below configuration. +solver: ??? + +fsdp: + use: false # should we use FSDP. + param_dtype: float16 # equivalent to autocast_dtype for FSDP. + reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability. + buffer_dtype: float32 # dtype used for buffers, we don't have much buffers, so let's leave it. + sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard. + # full_shard will use less memory but slower ?? + per_block: true # If True, uses nested FSDP. + +profiler: + enabled: false + +deadlock: + use: false + timeout: 600 + +dataset: + batch_size: ??? + num_workers: 10 + segment_duration: null + num_samples: null + return_info: false + shuffle: false + sample_on_duration: true + sample_on_weight: true + min_segment_ratio: 0.5 + train: + num_samples: null + shuffle: true + shuffle_seed: 0 # if you want to sample the data differently. + permutation_on_files: false + valid: + num_samples: null + evaluate: + num_samples: null + generate: + num_samples: null + return_info: true + +checkpoint: + save_last: true + save_every: null + keep_last: null + keep_every_states: null + +generate: + every: null + path: 'samples' + audio: + format: 'mp3' + strategy: 'clip' + sample_rate: null + lm: + use_sampling: false + temp: 1.0 + top_k: 0 + top_p: 0.0 +evaluate: + every: null + num_workers: 5 + truncate_audio: null + fixed_generation_duration: null # in secs + metrics: + base: true # run default evaluation (e.g. like train/valid stage) + +optim: + epochs: ??? + updates_per_epoch: null + lr: ??? + optimizer: ??? + adam: + betas: [0.9, 0.999] + weight_decay: 0. + ema: + use: false # whether to use EMA or not + updates: ${optim.updates_per_epoch} # frequency of updates of the EMA + device: cpu # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + +schedule: + lr_scheduler: null + step: + step_size: null + gamma: null + exponential: + lr_decay: null + cosine: + warmup: null + lr_min_ratio: 0.0 + cycle_length: 1.0 + polynomial_decay: + warmup: null + zero_lr_warmup_steps: 0 + end_lr: 0.0 + power: 1 + inverse_sqrt: + warmup: null + warmup_init_lr: 0.0 + linear_warmup: + warmup: null + warmup_init_lr: 0.0 diff --git a/config/solver/diffusion/debug.yaml b/config/solver/diffusion/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc27c53486f7215a080d167032972402b90f5c77 --- /dev/null +++ b/config/solver/diffusion/debug.yaml @@ -0,0 +1,106 @@ +# @package __global__ + +defaults: + - /solver/default + - /model: score/basic + - override /dset: audio/default + - _self_ + +solver: diffusion + +sample_rate: 16000 +channels: 1 +compression_model_checkpoint: //sig/5091833e +n_q: 2 # number of codebooks to keep + +dataset: + batch_size: 8 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 100 + valid: + num_samples: 100 + evaluate: + batch_size: 8 + num_samples: 10 + generate: + batch_size: 8 + num_samples: 10 + segment_duration: 10 + +loss: + kind: mse + norm_power: 0. + +valid: + every: 1 + +evaluate: + every: 5 + num_workers: 5 + metrics: + visqol: false + sisnr: false + rvm: true + +generate: + every: 5 + num_workers: 5 + audio: + sample_rate: ${sample_rate} + +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + + +optim: + epochs: 50 + updates_per_epoch: 2000 + lr: 2e-4 + max_norm: 0 + optimizer: adam + adam: + betas: [0.9, 0.999] + weight_decay: 0. + ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + +processor: + name: multi_band_processor + use: false + n_bands: 8 + num_samples: 10_000 + power_std: 1. + +resampling: + use: false + target_sr: 16000 + +filter: + use: false + n_bands: 4 + idx_band: 0 + cutoffs: null + +schedule: + repartition: "power" + variable_step_batch: true + beta_t0: 1.0e-5 + beta_t1: 2.9e-2 + beta_exp: 7.5 + num_steps: 1000 + variance: 'beta' + clip: 5. + rescale: 1. + n_bands: null + noise_scale: 1.0 + +metrics: + num_stage: 4 diff --git a/config/solver/diffusion/default.yaml b/config/solver/diffusion/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3793d4d08d912db575c022a6803a8909c2b25273 --- /dev/null +++ b/config/solver/diffusion/default.yaml @@ -0,0 +1,107 @@ +# @package __global__ + +defaults: + - /solver/default + - /model: score/basic + - override /dset: audio/default + - _self_ + +solver: diffusion + +sample_rate: ??? +channels: ??? +compression_model_checkpoint: ??? +n_q: ??? # number of codebooks to keep + + +dataset: + batch_size: 128 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 500000 + valid: + num_samples: 10000 + evaluate: + batch_size: 16 + num_samples: 10000 + generate: + batch_size: 32 + num_samples: 50 + segment_duration: 10 + audio: + sample_rate: ${sample_rate} + +loss: + kind: mse + norm_power: 0. + +valid: + every: 1 + +evaluate: + every: 20 + num_workers: 5 + metrics: + visqol: false + sisnr: false + rvm: true + +generate: + every: 25 + num_workers: 5 + +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + + +optim: + epochs: 20000 + updates_per_epoch: 2000 + lr: 2e-4 + max_norm: 0 + optimizer: adam + adam: + betas: [0.9, 0.999] + weight_decay: 0. + ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + +processor: + name: multi_band_processor + use: false + n_bands: 8 + num_samples: 10_000 + power_std: 1. + +resampling: + use: false + target_sr: 16000 + +filter: + use: false + n_bands: 4 + idx_band: 0 + cutoffs: null + +schedule: + repartition: "power" + variable_step_batch: true + beta_t0: 1.0e-5 + beta_t1: 2.9e-2 + beta_exp: 7.5 + num_steps: 1000 + variance: 'beta' + clip: 5. + rescale: 1. + n_bands: null + noise_scale: 1.0 + +metrics: + num_stage: 4 diff --git a/config/solver/diffusion/encodec_24khz.yaml b/config/solver/diffusion/encodec_24khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..774e88f43d54980daef0c68d11717ddb7a214db1 --- /dev/null +++ b/config/solver/diffusion/encodec_24khz.yaml @@ -0,0 +1,11 @@ +# @package __global__ + +defaults: + - diffusion/default + - _self_ + + +sample_rate: 24000 +channels: 1 +compression_model_checkpoint: //pretrained/facebook/encodec_24khz +n_q: 4 # num quantizers, 3kbps diff --git a/config/solver/musicgen/debug.yaml b/config/solver/musicgen/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9734d1bf975065ab4e185f8831f9960335810655 --- /dev/null +++ b/config/solver/musicgen/debug.yaml @@ -0,0 +1,61 @@ +# @package __global__ + +# This is a minimal debugging configuration +# for MusicGen training solver +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /model/lm/model_scale: xsmall + - override /dset: audio/example + - _self_ + +autocast: false +compression_model_checkpoint: //pretrained/debug_compression_model +transformer_lm: + n_q: 4 + card: 400 + +conditioners: + description: + model: t5 + t5: + name: t5-small + +codebooks_pattern: + modeling: parallel + +channels: 1 +sample_rate: 32000 + +deadlock: + use: false # deadlock detection + +dataset: + batch_size: 4 + segment_duration: 5 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + audio: + strategy: peak + lm: + use_sampling: false + top_k: 0 + top_p: 0.0 + +checkpoint: + save_every: 0 + keep_last: 0 + +optim: + epochs: 2 + updates_per_epoch: 10 + optimizer: adamw + lr: 1e-4 + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: null diff --git a/config/solver/musicgen/default.yaml b/config/solver/musicgen/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8bdf9c74ba4ea5385dd1c2dabd0ae4905dbb501e --- /dev/null +++ b/config/solver/musicgen/default.yaml @@ -0,0 +1,129 @@ +# @package __global__ + +defaults: + - /solver/default + - /conditioner: none + - _self_ + - /solver/musicgen/evaluation: none + - override /dset: audio/default + +autocast: true +autocast_dtype: float16 + +solver: musicgen +sample_rate: ??? +channels: ??? +compression_model_checkpoint: ??? +# The following will set the num codebooks on the underlying +# model, this might be different from the actual value for n_q +# given to the transformer, when the model output is postprocessed, for instance +# for stereo channels. If not provided, default value for the compression model +# will be used. +compression_model_n_q: null + +tokens: + padding_with_special_token: false + +interleave_stereo_codebooks: + use: false + per_timestep: false + +cache: + path: + write: false + write_shard: 0 + write_num_shards: 1 + + +dataset: + batch_size: 128 + num_workers: 10 + segment_duration: 30 + min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence. + return_info: true + train: + num_samples: 1000000 # need a randomly large number here for AudioDataset + valid: + num_samples: 10000 + generate: + num_samples: 50 + +metrics: + fad: + use_gt: false + model: tf + tf: + bin: null # path to local frechet_audio_distance code + model_path: //reference/fad/vggish_model.ckpt + kld: + use_gt: false + model: passt + passt: + pretrained_length: 20 + text_consistency: + use_gt: false + model: clap + clap: + model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt + model_arch: 'HTSAT-base' + enable_fusion: false + chroma_cosine: + use_gt: false + model: chroma_base + chroma_base: + sample_rate: ${sample_rate} + n_chroma: 12 + radix2_exp: 14 + argmax: true + +generate: + every: 25 + num_workers: 5 + path: samples + audio: + format: wav + strategy: loudness + sample_rate: ${sample_rate} + loudness_headroom_db: 14 + lm: + prompted_samples: true + unprompted_samples: true + gen_gt_samples: false + prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4 + gen_duration: null # if not set, will use dataset.generate.segment_duration + remove_prompts: false + # generation params + use_sampling: false + temp: 1.0 + top_k: 0 + top_p: 0.0 +evaluate: + every: 25 + num_workers: 5 + metrics: + base: false + fad: false + kld: false + text_consistency: false + chroma_cosine: false + +checkpoint: + save_last: true + save_every: 50 + keep_last: 10 + keep_every_states: null + +optim: + epochs: 200 + updates_per_epoch: 2000 + lr: 1e-4 + optimizer: adamw + max_norm: 1.0 + eager_sync: true + adam: + betas: [0.9, 0.95] + weight_decay: 0.1 + eps: 1e-8 + +schedule: + lr_scheduler: null diff --git a/config/solver/musicgen/evaluation/none.yaml b/config/solver/musicgen/evaluation/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e739995ed6488700527529862a7a24f1afdcc7a --- /dev/null +++ b/config/solver/musicgen/evaluation/none.yaml @@ -0,0 +1,5 @@ +# @package __global__ + +dataset: + evaluate: + num_samples: 10000 diff --git a/config/solver/musicgen/evaluation/objective_eval.yaml b/config/solver/musicgen/evaluation/objective_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4881e9d86cddf36b306a75fb498253e1e12ec5be --- /dev/null +++ b/config/solver/musicgen/evaluation/objective_eval.yaml @@ -0,0 +1,24 @@ +# @package __global__ + +# Setup for execute only on musiccaps for audio generation +# evaluation with objective metrics +# execute_only=evaluate + +dataset: + max_audio_duration: null + # ensure the proper values are broadcasted here for evaluate + evaluate: + min_audio_duration: 1. # some metrics requires a minimum audio length + max_audio_duration: null # all samples from musiccaps should be < 20s + num_samples: null + segment_duration: null + generate: + min_audio_duration: 1. + max_audio_duration: null + num_samples: 500 + +evaluate: + metrics: + fad: true + kld: true + text_consistency: true diff --git a/config/solver/musicgen/musicgen_base_32khz.yaml b/config/solver/musicgen/musicgen_base_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b32c9c898a70718f91af862caa79f5553a5107e1 --- /dev/null +++ b/config/solver/musicgen/musicgen_base_32khz.yaml @@ -0,0 +1,55 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/default + - _self_ + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 192 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + lm: + use_sampling: true + top_k: 250 + top_p: 0.0 + +optim: + epochs: 500 + optimizer: dadam + lr: 1 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 diff --git a/config/solver/musicgen/musicgen_melody_32khz.yaml b/config/solver/musicgen/musicgen_melody_32khz.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ad3e0aeeb9583887d6e8ecd6d32a3dc69e102ed --- /dev/null +++ b/config/solver/musicgen/musicgen_melody_32khz.yaml @@ -0,0 +1,56 @@ +# @package __global__ + +# This is the training loop solver +# for the melody MusicGen model (text+chroma to music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /conditioner: chroma2music + - override /dset: audio/default + - _self_ + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 192 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +generate: + lm: + use_sampling: true + top_k: 250 + top_p: 0.0 + +optim: + epochs: 500 + optimizer: dadam + lr: 1 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 diff --git a/config/teams/default.yaml b/config/teams/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..407066df1e154208af2823a6e46d16df381c5d42 --- /dev/null +++ b/config/teams/default.yaml @@ -0,0 +1,12 @@ +default: + dora_dir: /tmp/audiocraft_${oc.env:USER} + partitions: + global: debug + team: debug + reference_dir: /tmp +darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc. + dora_dir: /tmp/audiocraft_${oc.env:USER} + partitions: + global: debug + team: debug + reference_dir: /tmp diff --git a/config/teams/labs.yaml b/config/teams/labs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..da350a94bc5758531ced5d9e4332624fe86f3d57 --- /dev/null +++ b/config/teams/labs.yaml @@ -0,0 +1,28 @@ +aws: + dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs + partitions: + global: learnlab + team: learnlab + reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference + dataset_mappers: + "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm" +fair: + dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs + partitions: + global: learnlab + team: learnlab + reference_dir: /large_experiments/audiocraft/reference + dataset_mappers: + "^/datasets01/datasets01": "/datasets01" +darwin: + dora_dir: /tmp/audiocraft_${oc.env:USER} + partitions: + global: debug + team: debug + reference_dir: /tmp +rsc: + dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs + partitions: + global: learn + team: learn + reference_dir: /checkpoint/audiocraft/shared/reference diff --git a/dataset/example/electro_1.json b/dataset/example/electro_1.json new file mode 100644 index 0000000000000000000000000000000000000000..eeffc95038a1e031fad5598f822ddf2538d7f4da --- /dev/null +++ b/dataset/example/electro_1.json @@ -0,0 +1 @@ +{"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]} diff --git a/dataset/example/electro_1.mp3 b/dataset/example/electro_1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..8fa509266df4ee76519b82bfbea247cb0b18bcda Binary files /dev/null and b/dataset/example/electro_1.mp3 differ diff --git a/dataset/example/electro_2.json b/dataset/example/electro_2.json new file mode 100644 index 0000000000000000000000000000000000000000..3ee91c89c1d4b603f3e4d3fcc029618dc110e730 --- /dev/null +++ b/dataset/example/electro_2.json @@ -0,0 +1 @@ +{"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": "", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []} diff --git a/dataset/example/electro_2.mp3 b/dataset/example/electro_2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..01ab323e4322d08546635861959b868c3d7b416b Binary files /dev/null and b/dataset/example/electro_2.mp3 differ diff --git a/demos/audiogen_demo.ipynb b/demos/audiogen_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d3ad73fbcf172ea291ee4d73729b35af75ccedaa --- /dev/null +++ b/demos/audiogen_demo.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# AudioGen\n", + "Welcome to AudioGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use AudioGen in different settings.\n", + "\n", + "First, we start by initializing AudioGen. For now, we provide only a medium sized model for AudioGen: `facebook/audiogen-medium` - 1.5B transformer decoder. \n", + "\n", + "**Important note:** This variant is different from the original AudioGen model presented at [\"AudioGen: Textually-guided audio generation\"](https://arxiv.org/abs/2209.15352) as the model architecture is similar to MusicGen with a smaller frame rate and multiple streams of tokens, allowing to reduce generation time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.models import AudioGen\n", + "\n", + "model = AudioGen.get_pretrained('facebook/audiogen-medium')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let us configure the generation parameters. Specifically, you can control the following:\n", + "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n", + "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n", + "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n", + "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n", + "* `duration` (float, optional): duration of the generated waveform. Defaults to 10.0.\n", + "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n", + "\n", + "When left unchanged, AudioGen will revert to its default parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.set_generation_params(\n", + " use_sampling=True,\n", + " top_k=250,\n", + " duration=5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can go ahead and start generating sound using one of the following modes:\n", + "* Audio continuation using `model.generate_continuation`\n", + "* Text-conditional samples using `model.generate`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Audio Continuation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import torchaudio\n", + "import torch\n", + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "def get_bip_bip(bip_duration=0.125, frequency=440,\n", + " duration=0.5, sample_rate=16000, device=\"cuda\"):\n", + " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n", + " t = torch.arange(\n", + " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n", + " wav = torch.cos(2 * math.pi * 440 * t)[None]\n", + " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n", + " envelope = (tp >= 0.5).float()\n", + " return wav * envelope" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Here we use a synthetic signal to prompt the generated audio.\n", + "res = model.generate_continuation(\n", + " get_bip_bip(0.125).expand(2, -1, -1), \n", + " 16000, ['Whistling with wind blowing', \n", + " 'Typing on a typewriter'], \n", + " progress=True)\n", + "display_audio(res, 16000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You can also use any audio from a file. Make sure to trim the file if it is too long!\n", + "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/sirens_and_a_humming_engine_approach_and_pass.mp3\")\n", + "prompt_duration = 2\n", + "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n", + "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n", + "display_audio(output, sample_rate=16000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Text-conditional Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "output = model.generate(\n", + " descriptions=[\n", + " 'Subway train blowing its horn',\n", + " 'A cat meowing',\n", + " ],\n", + " progress=True\n", + ")\n", + "display_audio(output, sample_rate=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/app.py b/demos/musicgen_app.py similarity index 55% rename from app.py rename to demos/musicgen_app.py index 0f92495d323f1c70a9c8dde3b7680e3f9491ab83..adfd7f00f693ace67773b4b0595a3069d5d194c4 100644 --- a/app.py +++ b/demos/musicgen_app.py @@ -9,33 +9,40 @@ import argparse from concurrent.futures import ProcessPoolExecutor +import logging import os from pathlib import Path import subprocess as sp +import sys from tempfile import NamedTemporaryFile import time import typing as tp import warnings +from einops import rearrange import torch import gradio as gr from audiocraft.data.audio_utils import convert_audio from audiocraft.data.audio import audio_write -from audiocraft.models import MusicGen +from audiocraft.models.encodec import InterleaveStereoCompressionModel +from audiocraft.models import MusicGen, MultiBandDiffusion MODEL = None # Last used model -IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '') -MAX_BATCH_SIZE = 6 +SPACE_ID = os.environ.get('SPACE_ID', '') +IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID +print(IS_BATCHED) +MAX_BATCH_SIZE = 12 BATCHED_DURATION = 15 INTERRUPTING = False +MBD = None # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform _old_call = sp.call def _call_nostderr(*args, **kwargs): - # Avoid ffmpeg vomitting on the logs. + # Avoid ffmpeg vomiting on the logs. kwargs['stderr'] = sp.DEVNULL kwargs['stdout'] = sp.DEVNULL _old_call(*args, **kwargs) @@ -43,7 +50,7 @@ def _call_nostderr(*args, **kwargs): sp.call = _call_nostderr # Preallocating the pool of processes. -pool = ProcessPoolExecutor(3) +pool = ProcessPoolExecutor(4) pool.__enter__() @@ -85,14 +92,23 @@ def make_waveform(*args, **kwargs): return out -def load_model(version='melody'): +def load_model(version='facebook/musicgen-melody'): global MODEL print("Loading model", version) if MODEL is None or MODEL.name != version: + del MODEL + MODEL = None # in case loading would crash MODEL = MusicGen.get_pretrained(version) -def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): +def load_diffusion(): + global MBD + if MBD is None: + print("loading MBD") + MBD = MultiBandDiffusion.get_mbd_musicgen() + + +def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs): MODEL.set_generation_params(duration=duration, **gen_kwargs) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time() @@ -110,44 +126,71 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): melody = convert_audio(melody, sr, target_sr, target_ac) processed_melodies.append(melody) - if any(m is not None for m in processed_melodies): - outputs = MODEL.generate_with_chroma( - descriptions=texts, - melody_wavs=processed_melodies, - melody_sample_rate=target_sr, - progress=progress, - ) - else: - outputs = MODEL.generate(texts, progress=progress) - + try: + if any(m is not None for m in processed_melodies): + outputs = MODEL.generate_with_chroma( + descriptions=texts, + melody_wavs=processed_melodies, + melody_sample_rate=target_sr, + progress=progress, + return_tokens=USE_DIFFUSION + ) + else: + outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION) + except RuntimeError as e: + raise gr.Error("Error while generating " + e.args[0]) + if USE_DIFFUSION: + if gradio_progress is not None: + gradio_progress(1, desc='Running MultiBandDiffusion...') + tokens = outputs[1] + if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel): + left, right = MODEL.compression_model.get_left_right_codes(tokens) + tokens = torch.cat([left, right]) + outputs_diffusion = MBD.tokens_to_wav(tokens) + if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel): + assert outputs_diffusion.shape[1] == 1 # output is mono + outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2) + outputs = torch.cat([outputs[0], outputs_diffusion], dim=0) outputs = outputs.detach().cpu().float() - out_files = [] + pending_videos = [] + out_wavs = [] for output in outputs: with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: audio_write( file.name, output, MODEL.sample_rate, strategy="loudness", loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) - out_files.append(pool.submit(make_waveform, file.name)) + pending_videos.append(pool.submit(make_waveform, file.name)) + out_wavs.append(file.name) file_cleaner.add(file.name) - res = [out_file.result() for out_file in out_files] - for file in res: - file_cleaner.add(file) + out_videos = [pending_video.result() for pending_video in pending_videos] + for video in out_videos: + file_cleaner.add(video) print("batch finished", len(texts), time.time() - be) print("Tempfiles currently stored: ", len(file_cleaner.files)) - return res + return out_videos, out_wavs def predict_batched(texts, melodies): max_text_length = 512 texts = [text[:max_text_length] for text in texts] - load_model('melody') + load_model('facebook/musicgen-stereo-melody') res = _do_predictions(texts, melodies, BATCHED_DURATION) - return [res] + return res -def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()): +def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()): global INTERRUPTING + global USE_DIFFUSION INTERRUPTING = False + progress(0, desc="Loading model...") + model_path = model_path.strip() + if model_path: + if not Path(model_path).exists(): + raise gr.Error(f"Model path {model_path} doesn't exist.") + if not Path(model_path).is_dir(): + raise gr.Error(f"Model path {model_path} must be a folder containing " + "state_dict.bin and compression_state_dict_.bin.") + model = model_path if temperature < 0: raise gr.Error("Temperature must be >= 0.") if topk < 0: @@ -156,18 +199,31 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe raise gr.Error("Topp must be non-negative.") topk = int(topk) + if decoder == "MultiBand_Diffusion": + USE_DIFFUSION = True + progress(0, desc="Loading diffusion model...") + load_diffusion() + else: + USE_DIFFUSION = False load_model(model) + max_generated = 0 + def _progress(generated, to_generate): - progress((generated, to_generate)) + nonlocal max_generated + max_generated = max(generated, max_generated) + progress((min(max_generated, to_generate), to_generate)) if INTERRUPTING: raise gr.Error("Interrupted.") MODEL.set_custom_progress_callback(_progress) - outs = _do_predictions( + videos, wavs = _do_predictions( [text], [melody], duration, progress=True, - top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef) - return outs[0] + top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, + gradio_progress=progress) + if USE_DIFFUSION: + return videos[0], wavs[0], videos[1], wavs[1] + return videos[0], wavs[0], None, None def toggle_audio_src(choice): @@ -177,6 +233,13 @@ def toggle_audio_src(choice): return gr.update(source="upload", value=None, label="File") +def toggle_diffusion(choice): + if choice == "MultiBand_Diffusion": + return [gr.update(visible=True)] * 2 + else: + return [gr.update(visible=False)] * 2 + + def ui_full(launch_kwargs): with gr.Blocks() as interface: gr.Markdown( @@ -201,8 +264,16 @@ def ui_full(launch_kwargs): # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) with gr.Row(): - model = gr.Radio(["melody", "medium", "small", "large"], - label="Model", value="melody", interactive=True) + model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small", + "facebook/musicgen-large", "facebook/musicgen-melody-large", + "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium", + "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large", + "facebook/musicgen-stereo-melody-large"], + label="Model", value="facebook/musicgen-stereo-melody", interactive=True) + model_path = gr.Text(label="Model Path (custom models)") + with gr.Row(): + decoder = gr.Radio(["Default", "MultiBand_Diffusion"], + label="Decoder", value="Default", interactive=True) with gr.Row(): duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True) with gr.Row(): @@ -212,40 +283,56 @@ def ui_full(launch_kwargs): cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True) with gr.Column(): output = gr.Video(label="Generated Music") - submit.click(predict_full, - inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], - outputs=[output]) + audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') + diffusion_output = gr.Video(label="MultiBand Diffusion Decoder") + audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath') + submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False, + show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp, + temperature, cfg_coef], + outputs=[output, audio_output, diffusion_output, audio_diffusion]) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) + gr.Examples( fn=predict_full, examples=[ [ "An 80s driving pop song with heavy drums and synth pads in the background", "./assets/bach.mp3", - "melody" + "facebook/musicgen-stereo-melody", + "Default" ], [ "A cheerful country song with acoustic guitars", "./assets/bolero_ravel.mp3", - "melody" + "facebook/musicgen-stereo-melody", + "Default" ], [ "90s rock song with electric guitar and heavy drums", None, - "medium" + "facebook/musicgen-stereo-medium", + "Default" ], [ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", "./assets/bach.mp3", - "melody" + "facebook/musicgen-stereo-melody", + "Default" ], [ "lofi slow bpm electro chill with organic samples", None, - "medium", + "facebook/musicgen-stereo-medium", + "Default" + ], + [ + "Punk rock with loud drum and power guitar", + None, + "facebook/musicgen-stereo-medium", + "MultiBand_Diffusion" ], ], - inputs=[text, melody, model], + inputs=[text, melody, model, decoder], outputs=[output] ) gr.Markdown( @@ -253,8 +340,18 @@ def ui_full(launch_kwargs): ### More details The model will generate a short music extract based on the description you provided. - The model can generate up to 30 seconds of audio in one pass. It is now possible - to extend the generation by feeding back the end of the previous chunk of audio. + The model can generate up to 30 seconds of audio in one pass. + + The model was trained with description from a stock music catalog, descriptions that will work best + should include some level of details on the instruments present, along with some intended use case + (e.g. adding "perfect for a commercial" can somehow help). + + Using one of the `melody` model (e.g. `musicgen-melody-*`), you can optionally provide a reference audio + from which a broad melody will be extracted. + The model will then try to follow both the description and melody provided. + For best results, the melody should be 30 seconds long (I know, the samples we provide are not...) + + It is now possible to extend the generation by feeding back the end of the previous chunk of audio. This can take a long time, and the model might lose consistency. The model might also decide at arbitrary positions that the song ends. @@ -262,19 +359,23 @@ def ui_full(launch_kwargs): An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds are generated each time. - We present 4 model variations: - 1. Melody -- a music generation model capable of generating music condition + We present 10 model variations: + 1. facebook/musicgen-melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only. - 2. Small -- a 300M transformer decoder conditioned on text only. - 3. Medium -- a 1.5B transformer decoder conditioned on text only. - 4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.) - - When using `melody`, ou can optionaly provide a reference audio from - which a broad melody will be extracted. The model will then try to follow both - the description and melody provided. - - You can also use your own GPU or a Google Colab by following the instructions on our repo. - See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) + 2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only. + 3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only. + 4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only. + 5. facebook/musicgen-melody-large -- a 3.3B transformer decoder conditioned on and melody. + 6. facebook/musicgen-stereo-*: same as the previous models but fine tuned to output stereo audio. + + We also present two way of decoding the audio tokens + 1. Use the default GAN based compression model. It can suffer from artifacts especially + for crashes, snares etc. + 2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). Should improve the audio quality, + at an extra computational cost. When this is selected, we provide both the GAN based decoded + audio, and the one obtained with MBD. + + See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details. """ ) @@ -288,7 +389,7 @@ def ui_batched(launch_kwargs): """ # MusicGen - This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), + This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md), a simple and controllable model for music generation presented at: ["Simple and Controllable Music Generation"](https://huggingface.co./papers/2306.05284).
@@ -312,8 +413,9 @@ def ui_batched(launch_kwargs): submit = gr.Button("Generate") with gr.Column(): output = gr.Video(label="Generated Music") + audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') submit.click(predict_batched, inputs=[text, melody], - outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE) + outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) gr.Examples( fn=predict_batched, @@ -345,15 +447,27 @@ def ui_batched(launch_kwargs): gr.Markdown(""" ### More details - The model will generate 12 seconds of audio based on the description you provided. - You can optionaly provide a reference audio from which a broad melody will be extracted. + The model will generate 15 seconds of audio based on the description you provided. + The model was trained with description from a stock music catalog, descriptions that will work best + should include some level of details on the instruments present, along with some intended use case + (e.g. adding "perfect for a commercial" can somehow help). + + You can optionally provide a reference audio from which a broad melody will be extracted. The model will then try to follow both the description and melody provided. - All samples are generated with the `melody` model. + For best results, the melody should be 30 seconds long (I know, the samples we provide are not...) - You can also use your own GPU or a Google Colab by following the instructions on our repo. + You can access more control (longer generation, more models etc.) by clicking + the + Duplicate Space + (you will then need a paid GPU from HuggingFace). + If you have a GPU, you can run the gradio demo locally (click the link to our repo below for more info). + Finally, you can get a GPU for free from Google + and run the demo in [a Google Colab.](https://ai.honu.io/red/musicgen-colab). - See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) - for more details. + See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) + for more details. All samples are generated with the `stereo-melody` model. """) demo.queue(max_size=8 * 4).launch(**launch_kwargs) @@ -400,8 +514,12 @@ if __name__ == "__main__": if args.share: launch_kwargs['share'] = args.share + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + # Show the interface if IS_BATCHED: + global USE_DIFFUSION + USE_DIFFUSION = False ui_batched(launch_kwargs) else: ui_full(launch_kwargs) diff --git a/demo.ipynb b/demos/musicgen_demo.ipynb similarity index 70% rename from demo.ipynb rename to demos/musicgen_demo.ipynb index 4ea3d3643cbbf562d520e9c4cf7d71e479a8b50a..f8deacd90702c1164f5977ed68d0d89a2d222dbb 100644 --- a/demo.ipynb +++ b/demos/musicgen_demo.ipynb @@ -8,24 +8,28 @@ "Welcome to MusicGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n", "\n", "First, we start by initializing MusicGen, you can choose a model from the following selection:\n", - "1. `small` - 300M transformer decoder.\n", - "2. `medium` - 1.5B transformer decoder.\n", - "3. `melody` - 1.5B transformer decoder also supporting melody conditioning.\n", - "4. `large` - 3.3B transformer decoder.\n", + "1. `facebook/musicgen-small` - 300M transformer decoder.\n", + "2. `facebook/musicgen-medium` - 1.5B transformer decoder.\n", + "3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.\n", + "4. `facebook/musicgen-large` - 3.3B transformer decoder.\n", "\n", - "We will use the `small` variant for the purpose of this demonstration." + "We will use the `facebook/musicgen-small` variant for the purpose of this demonstration." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from audiocraft.models import MusicGen\n", + "from audiocraft.models import MultiBandDiffusion\n", "\n", + "USE_DIFFUSION_DECODER = False\n", "# Using small model, better results would be obtained with `medium` or `large`.\n", - "model = MusicGen.get_pretrained('small')" + "model = MusicGen.get_pretrained('facebook/musicgen-small')\n", + "if USE_DIFFUSION_DECODER:\n", + " mbd = MultiBandDiffusion.get_mbd_musicgen()" ] }, { @@ -52,7 +56,7 @@ "model.set_generation_params(\n", " use_sampling=True,\n", " top_k=250,\n", - " duration=5\n", + " duration=30\n", ")" ] }, @@ -67,25 +71,6 @@ "* Melody-conditional samples using `model.generate_with_chroma`" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Unconditional Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from audiocraft.utils.notebook import display_audio\n", - "\n", - "output = model.generate_unconditional(num_samples=2, progress=True)\n", - "display_audio(output, sample_rate=32000)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -112,7 +97,7 @@ " wav = torch.cos(2 * math.pi * 440 * t)[None]\n", " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n", " envelope = (tp >= 0.5).float()\n", - " return wav * envelope\n" + " return wav * envelope" ] }, { @@ -138,11 +123,14 @@ "outputs": [], "source": [ "# You can also use any audio from a file. Make sure to trim the file if it is too long!\n", - "prompt_waveform, prompt_sr = torchaudio.load(\"./assets/bach.mp3\")\n", + "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/bach.mp3\")\n", "prompt_duration = 2\n", "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n", - "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n", - "display_audio(output, sample_rate=32000)" + "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=True)\n", + "display_audio(output[0], sample_rate=32000)\n", + "if USE_DIFFUSION_DECODER:\n", + " out_diffusion = mbd.tokens_to_wav(output[1])\n", + " display_audio(out_diffusion, sample_rate=32000)" ] }, { @@ -162,12 +150,20 @@ "\n", "output = model.generate(\n", " descriptions=[\n", - " '80s pop track with bassy drums and synth',\n", - " '90s rock song with loud guitars and heavy drums',\n", + " #'80s pop track with bassy drums and synth',\n", + " #'90s rock song with loud guitars and heavy drums',\n", + " #'Progressive rock drum and bass solo',\n", + " #'Punk Rock song with loud drum and power guitar',\n", + " #'Bluesy guitar instrumental with soulful licks and a driving rhythm section',\n", + " #'Jazz Funk song with slap bass and powerful saxophone',\n", + " 'drum and bass beat with intense percussions'\n", " ],\n", - " progress=True\n", + " progress=True, return_tokens=True\n", ")\n", - "display_audio(output, sample_rate=32000)" + "display_audio(output[0], sample_rate=32000)\n", + "if USE_DIFFUSION_DECODER:\n", + " out_diffusion = mbd.tokens_to_wav(output[1])\n", + " display_audio(out_diffusion, sample_rate=32000)" ] }, { @@ -186,10 +182,10 @@ "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "\n", - "model = MusicGen.get_pretrained('melody')\n", + "model = MusicGen.get_pretrained('facebook/musicgen-melody')\n", "model.set_generation_params(duration=8)\n", "\n", - "melody_waveform, sr = torchaudio.load(\"assets/bach.mp3\")\n", + "melody_waveform, sr = torchaudio.load(\"../assets/bach.mp3\")\n", "melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n", "output = model.generate_with_chroma(\n", " descriptions=[\n", @@ -198,17 +194,13 @@ " ],\n", " melody_wavs=melody_waveform,\n", " melody_sample_rate=sr,\n", - " progress=True\n", + " progress=True, return_tokens=True\n", ")\n", - "display_audio(output, sample_rate=32000)" + "display_audio(output[0], sample_rate=32000)\n", + "if USE_DIFFUSION_DECODER:\n", + " out_diffusion = mbd.tokens_to_wav(output[1])\n", + " display_audio(out_diffusion, sample_rate=32000)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -227,7 +219,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.9.16" + }, + "vscode": { + "interpreter": { + "hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103" + } } }, "nbformat": 4, diff --git a/docs/AUDIOGEN.md b/docs/AUDIOGEN.md new file mode 100644 index 0000000000000000000000000000000000000000..a0ff481190fb52fe865aa66aaaa10176f7cf995c --- /dev/null +++ b/docs/AUDIOGEN.md @@ -0,0 +1,158 @@ +# AudioGen: Textually-guided audio generation + +AudioCraft provides the code and a model re-implementing AudioGen, a [textually-guided audio generation][audiogen_arxiv] +model that performs text-to-sound generation. + +The provided AudioGen reimplementation follows the LM model architecture introduced in [MusicGen][musicgen_arxiv] +and is a single stage auto-regressive Transformer model trained over a 16kHz +EnCodec tokenizer with 4 codebooks sampled at 50 Hz. +This model variant reaches similar audio quality than the original implementation introduced in the AudioGen publication +while providing faster generation speed given the smaller frame rate. + +**Important note:** The provided models are NOT the original models used to report numbers in the +[AudioGen publication][audiogen_arxiv]. Refer to the model card to learn more about architectural changes. + +Listen to samples from the **original AudioGen implementation** in our [sample page][audiogen_samples]. + + +## Model Card + +See [the model card](../model_cards/AUDIOGEN_MODEL_CARD.md). + + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + +AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters). + +## API and usage + +We provide a simple API and 1 pre-trained models for AudioGen: + +`facebook/audiogen-medium`: 1.5B model, text to sound - [🤗 Hub](https://huggingface.co./facebook/audiogen-medium) + +You can play with AudioGen by running the jupyter notebook at [`demos/audiogen_demo.ipynb`](../demos/audiogen_demo.ipynb) locally (if you have a GPU). + +See after a quick example for using the API. + +```python +import torchaudio +from audiocraft.models import AudioGen +from audiocraft.data.audio import audio_write + +model = AudioGen.get_pretrained('facebook/audiogen-medium') +model.set_generation_params(duration=5) # generate 5 seconds. +descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor'] +wav = model.generate(descriptions) # generates 3 samples. + +for idx, one_wav in enumerate(wav): + # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. + audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) +``` + +## Training + +The [AudioGenSolver](../audiocraft/solvers/audiogen.py) implements the AudioGen's training pipeline +used to develop the released model. Note that this may not fully reproduce the results presented in the paper. +Similarly to MusicGen, it defines an autoregressive language modeling task over multiple streams of +discrete tokens extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) +for more details on how to train such model) with dataset-specific changes for environmental sound +processing. + +Note that **we do NOT provide any of the datasets** used for training AudioGen. + +### Example configurations and grids + +We provide configurations to reproduce the released models and our research. +AudioGen solvers configuration are available in [config/solver/audiogen](../config/solver/audiogen). +The base training configuration used for the released models is the following: +[`solver=audiogen/audiogen_base_16khz`](../config/solver/audiogen/audiogen_base_16khz.yaml) + +Please find some example grids to train AudioGen at +[audiocraft/grids/audiogen](../audiocraft/grids/audiogen/). + +```shell +# text-to-sound +dora grid audiogen.audiogen_base_16khz +``` + +### Sound dataset and metadata + +AudioGen's underlying dataset is an AudioDataset augmented with description metadata. +The AudioGen dataset implementation expects the metadata to be available as `.json` files +at the same location as the audio files or through specified external folder. +Learn more in the [datasets section](./DATASETS.md). + +### Evaluation stage + +By default, evaluation stage is also computing the cross-entropy and the perplexity over the +evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run +or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md) +for more details on the requirements for each metric. + +We provide an off-the-shelf configuration to enable running the objective metrics +for audio generation in +[config/solver/audiogen/evaluation/objective_eval](../config/solver/audiogen/evaluation/objective_eval.yaml). + +One can then activate evaluation the following way: +```shell +# using the configuration +dora run solver=audiogen/debug solver/audiogen/evaluation=objective_eval +# specifying each of the fields, e.g. to activate KL computation +dora run solver=audiogen/debug evaluate.metrics.kld=true +``` + +See [an example evaluation grid](../audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py). + +### Generation stage + +The generation stage allows to generate samples conditionally and/or unconditionally and to perform +audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling +from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples +generated and the batch size used are controlled by the `dataset.generate` configuration +while the other generation parameters are defined in `generate.lm`. + +```shell +# control sampling parameters +dora run solver=audiogen/debug generate.lm.gen_duration=5 generate.lm.use_sampling=true generate.lm.top_k=15 +``` + +## More information + +Refer to [MusicGen's instructions](./MUSICGEN.md). + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation + +AudioGen +``` +@article{kreuk2022audiogen, + title={Audiogen: Textually guided audio generation}, + author={Kreuk, Felix and Synnaeve, Gabriel and Polyak, Adam and Singer, Uriel and D{\'e}fossez, Alexandre and Copet, Jade and Parikh, Devi and Taigman, Yaniv and Adi, Yossi}, + journal={arXiv preprint arXiv:2209.15352}, + year={2022} +} +``` + +MusicGen +``` +@article{copet2023simple, + title={Simple and Controllable Music Generation}, + author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, + year={2023}, + journal={arXiv preprint arXiv:2306.05284}, +} +``` + +## License + +See license information in the [model card](../model_cards/AUDIOGEN_MODEL_CARD.md). + +[audiogen_arxiv]: https://arxiv.org/abs/2209.15352 +[musicgen_arxiv]: https://arxiv.org/abs/2306.05284 +[audiogen_samples]: https://felixkreuk.github.io/audiogen/ diff --git a/docs/CONDITIONING.md b/docs/CONDITIONING.md new file mode 100644 index 0000000000000000000000000000000000000000..6e356cb8e9912d3e18fc84598c1acf77c6e7abc5 --- /dev/null +++ b/docs/CONDITIONING.md @@ -0,0 +1,146 @@ +# AudioCraft conditioning modules + +AudioCraft provides a +[modular implementation of conditioning modules](../audiocraft/modules/conditioners.py) +that can be used with the language model to condition the generation. +The codebase was developed in order to easily extend the set of modules +currently supported to easily develop new ways of controlling the generation. + + +## Conditioning methods + +For now, we support 3 main types of conditioning within AudioCraft: +* Text-based conditioning methods +* Waveform-based conditioning methods +* Joint embedding conditioning methods for text and audio projected in a shared latent space. + +The Language Model relies on 2 core components that handle processing information: +* The `ConditionProvider` class, that maps metadata to processed conditions leveraging +all the defined conditioners for the given task. +* The `ConditionFuser` class, that takes preprocessed conditions and properly fuse the +conditioning embedding to the language model inputs following a given fusing strategy. + +Different conditioners (for text, waveform, joint embeddings...) are provided as torch +modules in AudioCraft and are used internally in the language model to process the +conditioning signals and feed them to the language model. + + +## Core concepts + +### Conditioners + +The `BaseConditioner` torch module is the base implementation for all conditioners in audiocraft. + +Each conditioner is expected to implement 2 methods: +* The `tokenize` method that is used as a preprocessing method that contains all processing +that can lead to synchronization points (e.g. BPE tokenization with transfer to the GPU). +The output of the tokenize method will then be used to feed the forward method. +* The `forward` method that takes the output of the tokenize method and contains the core computation +to obtain the conditioning embedding along with a mask indicating valid indices (e.g. padding tokens). + +### ConditionProvider + +The ConditionProvider prepares and provides conditions given a dictionary of conditioners. + +Conditioners are specified as a dictionary of attributes and the corresponding conditioner +providing the processing logic for the given attribute. + +Similarly to the conditioners, the condition provider works in two steps to avoid sychronization points: +* A `tokenize` method that takes a list of conditioning attributes for the batch, +and run all tokenize steps for the set of conditioners. +* A `forward` method that takes the output of the tokenize step and run all the forward steps +for the set of conditioners. + +The list of conditioning attributes is passed as a list of `ConditioningAttributes` +that is presented just below. + +### ConditionFuser + +Once all conditioning signals have been extracted and processed by the `ConditionProvider` +as dense embeddings, they remain to be passed to the language model along with the original +language model inputs. + +The `ConditionFuser` handles specifically the logic to combine the different conditions +to the actual model input, supporting different strategies to combine them. + +One can therefore define different strategies to combine or fuse the condition to the input, in particular: +* Prepending the conditioning signal to the input with the `prepend` strategy, +* Summing the conditioning signal to the input with the `sum` strategy, +* Combining the conditioning relying on a cross-attention mechanism with the `cross` strategy, +* Using input interpolation with the `input_interpolate` strategy. + +### SegmentWithAttributes and ConditioningAttributes: From metadata to conditions + +The `ConditioningAttributes` dataclass is the base class for metadata +containing all attributes used for conditioning the language model. + +It currently supports the following types of attributes: +* Text conditioning attributes: Dictionary of textual attributes used for text-conditioning. +* Wav conditioning attributes: Dictionary of waveform attributes used for waveform-based +conditioning such as the chroma conditioning. +* JointEmbed conditioning attributes: Dictionary of text and waveform attributes +that are expected to be represented in a shared latent space. + +These different types of attributes are the attributes that are processed +by the different conditioners. + +`ConditioningAttributes` are extracted from metadata loaded along the audio in the datasets, +provided that the metadata used by the dataset implements the `SegmentWithAttributes` abstraction. + +All metadata-enabled datasets to use for conditioning in AudioCraft inherits +the [`audiocraft.data.info_dataset.InfoAudioDataset`](../audiocraft/data/info_audio_dataset.py) class +and the corresponding metadata inherits and implements the `SegmentWithAttributes` abstraction. +Refer to the [`audiocraft.data.music_dataset.MusicAudioDataset`](../audiocraft/data/music_dataset.py) +class as an example. + + +## Available conditioners + +### Text conditioners + +All text conditioners are expected to inherit from the `TextConditioner` class. + +AudioCraft currently provides two text conditioners: +* The `LUTConditioner` that relies on look-up-table of embeddings learned at train time, +and relying on either no tokenizer or a spacy tokenizer. This conditioner is particularly +useful for simple experiments and categorical labels. +* The `T5Conditioner` that relies on a +[pre-trained T5 model](https://huggingface.co./docs/transformers/model_doc/t5) +frozen or fine-tuned at train time to extract the text embeddings. + +### Waveform conditioners + +All waveform conditioners are expected to inherit from the `WaveformConditioner` class and +consists of conditioning method that takes a waveform as input. The waveform conditioner +must implement the logic to extract the embedding from the waveform and define the downsampling +factor from the waveform to the resulting embedding. + +The `ChromaStemConditioner` conditioner is a waveform conditioner for the chroma features +conditioning used by MusicGen. It takes a given waveform, extract relevant stems for melody +(namely all non drums and bass stems) using a +[pre-trained Demucs model](https://github.com/facebookresearch/demucs) +and then extract the chromagram bins from the remaining mix of stems. + +### Joint embeddings conditioners + +We finally provide support for conditioning based on joint text and audio embeddings through +the `JointEmbeddingConditioner` class and the `CLAPEmbeddingConditioner` that implements such +a conditioning method relying on a [pretrained CLAP model](https://github.com/LAION-AI/CLAP). + +## Classifier Free Guidance + +We provide a Classifier Free Guidance implementation in AudioCraft. With the classifier free +guidance dropout, all attributes are dropped with the same probability. + +## Attribute Dropout + +We further provide an attribute dropout strategy. Unlike the classifier free guidance dropout, +the attribute dropout drops given attributes with a defined probability, allowing the model +not to expect all conditioning signals to be provided at once. + +## Faster computation of conditions + +Conditioners that require some heavy computation on the waveform can be cached, in particular +the `ChromaStemConditioner` or `CLAPEmbeddingConditioner`. You just need to provide the +`cache_path` parameter to them. We recommend running dummy jobs for filling up the cache quickly. +An example is provied in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py). \ No newline at end of file diff --git a/docs/DATASETS.md b/docs/DATASETS.md new file mode 100644 index 0000000000000000000000000000000000000000..b0890c03cf732450eb498559638c6b45d50e40c3 --- /dev/null +++ b/docs/DATASETS.md @@ -0,0 +1,82 @@ +# AudioCraft datasets + +Our dataset manifest files consist in 1-json-per-line files, potentially gzipped, +as `data.jsons` or `data.jsons.gz` files. This JSON contains the path to the audio +file and associated metadata. The manifest files are then provided in the configuration, +as `datasource` sub-configuration. A datasource contains the pointers to the paths of +the manifest files for each AudioCraft stage (or split) along with additional information +(eg. maximum sample rate to use against this dataset). All the datasources are under the +`dset` group config, with a dedicated configuration file for each dataset. + +## Getting started + +### Example + +See the provided example in the directory that provides a manifest to use the example dataset +provided under the [dataset folder](../dataset/example). + +The manifest files are stored in the [egs folder](../egs/example). + +```shell +egs/ + example/data.json.gz +``` + +A datasource is defined in the configuration folder, in the dset group config for this dataset +at [config/dset/audio/example](../config/dset/audio/example.yaml): + +```shell +# @package __global__ + +datasource: + max_sample_rate: 44100 + max_channels: 2 + + train: egs/example + valid: egs/example + evaluate: egs/example + generate: egs/example +``` + +For proper dataset, one should create manifest for each of the splits and specify the correct path +to the given manifest in the datasource for each split. + +Then, using a dataset through the configuration can be done pointing to the +corresponding dataset configuration: +```shell +dset= # should match the yaml file name + +# for example +dset=audio/example +``` + +### Creating manifest files + +Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use +the following command to create new manifest files from a given folder containing audio files: + +```shell +python -m audiocraft.data.audio_dataset egs/my_dataset/my_dataset_split/data.jsonl.gz + +# For example to generate the manifest for dset=audio/example +# note: we don't use any split and we don't compress the jsonl file for this dummy example +python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl + +# More info with: python -m audiocraft.data.audio_dataset --help +``` + +## Additional information + +### MusicDataset and metadata + +The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects +the additional metadata to be stored in a JSON file that has the same path as the corresponding +audio file, but with a `.json` extension. + +### SoundDataset and metadata + +The SoundDataset is an AudioDataset with descriptions metadata. Similarly to the MusicDataset, +the SoundDataset expects the additional metadata to be stored in a JSON file that has the same +path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset +supports an additional parameter pointing to an extra folder `external_metadata_source` containing +all the JSON metadata files given they have the same filename as the audio file. diff --git a/docs/ENCODEC.md b/docs/ENCODEC.md new file mode 100644 index 0000000000000000000000000000000000000000..efc2bcc7ec50190b907c887b920b70fd799c6953 --- /dev/null +++ b/docs/ENCODEC.md @@ -0,0 +1,179 @@ +# EnCodec: High Fidelity Neural Audio Compression + +AudioCraft provides the training code for EnCodec, a state-of-the-art deep learning +based audio codec supporting both mono stereo audio, presented in the +[High Fidelity Neural Audio Compression][arxiv] paper. +Check out our [sample page][encodec_samples]. + +## Original EnCodec models + +The EnCodec models presented in High Fidelity Neural Audio Compression can be accessed +and used with the [EnCodec repository](https://github.com/facebookresearch/encodec). + +**Note**: We do not guarantee compatibility between the AudioCraft and EnCodec codebases +and released checkpoints at this stage. + + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + + +## Training + +The [CompressionSolver](../audiocraft/solvers/compression.py) implements the audio reconstruction +task to train an EnCodec model. Specifically, it trains an encoder-decoder with a quantization +bottleneck - a SEANet encoder-decoder with Residual Vector Quantization bottleneck for EnCodec - +using a combination of objective and perceptual losses in the forms of discriminators. + +The default configuration matches a causal EnCodec training with at a single bandwidth. + +### Example configuration and grids + +We provide sample configuration and grids for training EnCodec models. + +The compression configuration are defined in +[config/solver/compression](../config/solver/compression). + +The example grids are available at +[audiocraft/grids/compression](../audiocraft/grids/compression). + +```shell +# base causal encodec on monophonic audio sampled at 24 khz +dora grid compression.encodec_base_24khz +# encodec model used for MusicGen on monophonic audio sampled at 32 khz +dora grid compression.encodec_musicgen_32khz +``` + +### Training and valid stages + +The model is trained using a combination of objective and perceptual losses. +More specifically, EnCodec is trained with the MS-STFT discriminator along with +objective losses through the use of a loss balancer to effectively weight +the different losses, in an intuitive manner. + +### Evaluation stage + +Evaluations metrics for audio generation: +* SI-SNR: Scale-Invariant Signal-to-Noise Ratio. +* ViSQOL: Virtual Speech Quality Objective Listener. + +Note: Path to the ViSQOL binary (compiled with bazel) needs to be provided in +order to run the ViSQOL metric on the reference and degraded signals. +The metric is disabled by default. +Please refer to the [metrics documentation](../METRICS.md) to learn more. + +### Generation stage + +The generation stage consists in generating the reconstructed audio from samples +with the current model. The number of samples generated and the batch size used are +controlled by the `dataset.generate` configuration. The output path and audio formats +are defined in the generate stage configuration. + +```shell +# generate samples every 5 epoch +dora run solver=compression/encodec_base_24khz generate.every=5 +# run with a different dset +dora run solver=compression/encodec_base_24khz generate.path= +# limit the number of samples or use a different batch size +dora grid solver=compression/encodec_base_24khz dataset.generate.num_samples=10 dataset.generate.batch_size=4 +``` + +### Playing with the model + +Once you have a model trained, it is possible to get the entire solver, or just +the trained model with the following functions: + +```python +from audiocraft.solvers import CompressionSolver + +# If you trained a custom model with signature SIG. +model = CompressionSolver.model_from_checkpoint('//sig/SIG') +# If you want to get one of the pretrained models with the `//pretrained/` prefix. +model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz') +# Or load from a custom checkpoint path +model = CompressionSolver.model_from_checkpoint('/my_checkpoints/foo/bar/checkpoint.th') + + +# If you only want to use a pretrained model, you can also directly get it +# from the CompressionModel base model class. +from audiocraft.models import CompressionModel + +# Here do not put the `//pretrained/` prefix! +model = CompressionModel.get_pretrained('facebook/encodec_32khz') +model = CompressionModel.get_pretrained('dac_44khz') + +# Finally, you can also retrieve the full Solver object, with its dataloader etc. +from audiocraft import train +from pathlib import Path +import logging +import os +import sys + +# uncomment the following line if you want some detailed logs when loading a Solver. +logging.basicConfig(stream=sys.stderr, level=logging.INFO) +# You must always run the following function from the root directory. +os.chdir(Path(train.__file__).parent.parent) + + +# You can also get the full solver (only for your own experiments). +# You can provide some overrides to the parameters to make things more convenient. +solver = train.get_solver_from_sig('SIG', {'device': 'cpu', 'dataset': {'batch_size': 8}}) +solver.model +solver.dataloaders +``` + +### Importing / Exporting models + +At the moment we do not have a definitive workflow for exporting EnCodec models, for +instance to Hugging Face (HF). We are working on supporting automatic convertion between +AudioCraft and Hugging Face implementations. + +We still have some support for fine tuning an EnCodec model coming from HF in AudioCraft, +using for instance `continue_from=//pretrained/facebook/encodec_32k`. + +An AudioCraft checkpoint can be exported in a more compact format (excluding the optimizer etc.) +using `audiocraft.utils.export.export_encodec`. For instance, you could run + +```python +from audiocraft.utils import export +from audiocraft import train +xp = train.main.get_xp_from_sig('SIG') +export.export_encodec( + xp.folder / 'checkpoint.th', + '/checkpoints/my_audio_lm/compression_state_dict.bin') + + +from audiocraft.models import CompressionModel +model = CompressionModel.get_pretrained('/checkpoints/my_audio_lm/compression_state_dict.bin') + +from audiocraft.solvers import CompressionSolver +# The two are strictly equivalent, but this function supports also loading from non already exported models. +model = CompressionSolver.model_from_checkpoint('//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin') +``` + +We will see then how to use this model as a tokenizer for MusicGen/Audio gen in the +[MusicGen documentation](./MUSICGEN.md). + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation +``` +@article{defossez2022highfi, + title={High Fidelity Neural Audio Compression}, + author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, + journal={arXiv preprint arXiv:2210.13438}, + year={2022} +} +``` + + +## License + +See license information in the [README](../README.md). + +[arxiv]: https://arxiv.org/abs/2210.13438 +[encodec_samples]: https://ai.honu.io/papers/encodec/samples.html diff --git a/docs/MBD.md b/docs/MBD.md new file mode 100644 index 0000000000000000000000000000000000000000..b6629184cfb47890632069e3fa68f237c9ae4a43 --- /dev/null +++ b/docs/MBD.md @@ -0,0 +1,117 @@ +# MultiBand Diffusion + +AudioCraft provides the code and models for MultiBand Diffusion, [From Discrete Tokens to High Fidelity Audio using MultiBand Diffusion][arxiv]. +MultiBand diffusion is a collection of 4 models that can decode tokens from +EnCodec tokenizer into waveform audio. You can listen to some examples on the sample page. + + + Open In Colab + +
+ + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + + +## Usage + +We offer a number of way to use MultiBand Diffusion: +1. The MusicGen demo includes a toggle to try diffusion decoder. You can use the demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py), or through the [MusicGen Colab](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing). +2. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). + +## API + +We provide a simple API and pre-trained models for MusicGen and for EnCodec at 24 khz for 3 bitrates (1.5 kbps, 3 kbps and 6 kbps). + +See after a quick example for using MultiBandDiffusion with the MusicGen API: + +```python +import torchaudio +from audiocraft.models import MusicGen, MultiBandDiffusion +from audiocraft.data.audio import audio_write + +model = MusicGen.get_pretrained('facebook/musicgen-melody') +mbd = MultiBandDiffusion.get_mbd_musicgen() +model.set_generation_params(duration=8) # generate 8 seconds. +wav, tokens = model.generate_unconditional(4, return_tokens=True) # generates 4 unconditional audio samples and keep the tokens for MBD generation +descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] +wav_diffusion = mbd.tokens_to_wav(tokens) +wav, tokens = model.generate(descriptions, return_tokens=True) # generates 3 samples and keep the tokens. +wav_diffusion = mbd.tokens_to_wav(tokens) +melody, sr = torchaudio.load('./assets/bach.mp3') +# Generates using the melody from the given audio and the provided descriptions, returns audio and audio tokens. +wav, tokens = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr, return_tokens=True) +wav_diffusion = mbd.tokens_to_wav(tokens) + +for idx, one_wav in enumerate(wav): + # Will save under {idx}.wav and {idx}_diffusion.wav, with loudness normalization at -14 db LUFS for comparing the methods. + audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) + audio_write(f'{idx}_diffusion', wav_diffusion[idx].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) +``` + +For the compression task (and to compare with [EnCodec](https://github.com/facebookresearch/encodec)): + +```python +import torch +from audiocraft.models import MultiBandDiffusion +from encodec import EncodecModel +from audiocraft.data.audio import audio_read, audio_write + +bandwidth = 3.0 # 1.5, 3.0, 6.0 +mbd = MultiBandDiffusion.get_mbd_24khz(bw=bandwidth) +encodec = EncodecModel.encodec_model_24khz() + +somepath = '' +wav, sr = audio_read(somepath) +with torch.no_grad(): + compressed_encodec = encodec(wav) + compressed_diffusion = mbd.regenerate(wav, sample_rate=sr) + +audio_write('sample_encodec', compressed_encodec.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True) +audio_write('sample_diffusion', compressed_diffusion.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True) +``` + + +## Training + +The [DiffusionSolver](../audiocraft/solvers/diffusion.py) implements our diffusion training pipeline. +It generates waveform audio conditioned on the embeddings extracted from a pre-trained EnCodec model +(see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model). + +Note that **we do NOT provide any of the datasets** used for training our diffusion models. +We provide a dummy dataset containing just a few examples for illustrative purposes. + +### Example configurations and grids + +One can train diffusion models as described in the paper by using this [dora grid](../audiocraft/grids/diffusion/4_bands_base_32khz.py). +```shell +# 4 bands MBD trainning +dora grid diffusion.4_bands_base_32khz +``` + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation + +``` +@article{sanroman2023fromdi, + title={From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion}, + author={San Roman, Robin and Adi, Yossi and Deleforge, Antoine and Serizel, Romain and Synnaeve, Gabriel and Défossez, Alexandre}, + journal={arXiv preprint arXiv:}, + year={2023} +} +``` + + +## License + +See license information in the [README](../README.md). + + +[arxiv]: https://arxiv.org/abs/2308.02560 +[mbd_samples]: https://ai.honu.io/papers/mbd/ diff --git a/docs/METRICS.md b/docs/METRICS.md new file mode 100644 index 0000000000000000000000000000000000000000..506ce35db708967bb9de6edf9c46df2564f0b0fd --- /dev/null +++ b/docs/METRICS.md @@ -0,0 +1,131 @@ +# AudioCraft objective metrics + +In addition to training losses, AudioCraft provides a set of objective metrics +for audio synthesis and audio generation. As these metrics may require +extra dependencies and can be costly to train, they are often disabled by default. +This section provides guidance for setting up and using these metrics in +the AudioCraft training pipelines. + +## Available metrics + +### Audio synthesis quality metrics + +#### SI-SNR + +We provide an implementation of the Scale-Invariant Signal-to-Noise Ratio in PyTorch. +No specific requirement is needed for this metric. Please activate the metric at the +evaluation stage with the appropriate flag: + +**Warning:** We report the opposite of the SI-SNR, e.g. multiplied by -1. This is due to internal + details where the SI-SNR score can also be used as a training loss function, where lower + values should indicate better reconstruction. Negative values are such expected and a good sign! Those should be again multiplied by `-1` before publication :) + +```shell +dora run <...> evaluate.metrics.sisnr=true +``` + +#### ViSQOL + +We provide a Python wrapper around the ViSQOL [official implementation](https://github.com/google/visqol) +to conveniently run ViSQOL within the training pipelines. + +One must specify the path to the ViSQOL installation through the configuration in order +to enable ViSQOL computations in AudioCraft: + +```shell +# the first parameter is used to activate visqol computation while the second specify +# the path to visqol's library to be used by our python wrapper +dora run <...> evaluate.metrics.visqol=true metrics.visqol.bin= +``` + +See an example grid: [Compression with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py) + +To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the +instructions available in the [open source repository](https://github.com/google/visqol). + +### Audio generation metrics + +#### Frechet Audio Distance + +Similarly to ViSQOL, we use a Python wrapper around the Frechet Audio Distance +[official implementation](https://github.com/google-research/google-research/tree/master/frechet_audio_distance) +in TensorFlow. + +Note that we had to make several changes to the actual code in order to make it work. +Please refer to the [FrechetAudioDistanceMetric](../audiocraft/metrics/fad.py) class documentation +for more details. We do not plan to provide further support in obtaining a working setup for the +Frechet Audio Distance at this stage. + +```shell +# the first parameter is used to activate FAD metric computation while the second specify +# the path to FAD library to be used by our python wrapper +dora run <...> evaluate.metrics.fad=true metrics.fad.bin= +``` + +See an example grid: [Evaluation with FAD](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py) + +#### Kullback-Leibler Divergence + +We provide a PyTorch implementation of the Kullback-Leibler Divergence computed over the probabilities +of the labels obtained by a state-of-the-art audio classifier. We provide our implementation of the KLD +using the [PaSST classifier](https://github.com/kkoutini/PaSST). + +In order to use the KLD metric over PaSST, you must install the PaSST library as an extra dependency: +```shell +pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt' +``` + +Then similarly, you can use the metric activating the corresponding flag: + +```shell +# one could extend the kld metric with additional audio classifier models that can then be picked through the configuration +dora run <...> evaluate.metrics.kld=true metrics.kld.model=passt +``` + +#### Text consistency + +We provide a text-consistency metric, similarly to the MuLan Cycle Consistency from +[MusicLM](https://arxiv.org/pdf/2301.11325.pdf) or the CLAP score used in +[Make-An-Audio](https://arxiv.org/pdf/2301.12661v1.pdf). +More specifically, we provide a PyTorch implementation of a Text consistency metric +relying on a pre-trained [Contrastive Language-Audio Pretraining (CLAP)](https://github.com/LAION-AI/CLAP). + +Please install the CLAP library as an extra dependency prior to using the metric: +```shell +pip install laion_clap +``` + +Then similarly, you can use the metric activating the corresponding flag: + +```shell +# one could extend the text consistency metric with additional audio classifier models that can then be picked through the configuration +dora run ... evaluate.metrics.text_consistency=true metrics.text_consistency.model=clap +``` + +Note that the text consistency metric based on CLAP will require the CLAP checkpoint to be +provided in the configuration. + +#### Chroma cosine similarity + +Finally, as introduced in MusicGen, we provide a Chroma Cosine Similarity metric in PyTorch. +No specific requirement is needed for this metric. Please activate the metric at the +evaluation stage with the appropriate flag: + +```shell +dora run ... evaluate.metrics.chroma_cosine=true +``` + +#### Comparing against reconstructed audio + +For all the above audio generation metrics, we offer the option to compute the metric on the reconstructed audio +fed in EnCodec instead of the generated sample using the flag `.use_gt=true`. + +## Example usage + +You will find example of configuration for the different metrics introduced above in: +* The [musicgen's default solver](../config/solver/musicgen/default.yaml) for all audio generation metrics +* The [compression's default solver](../config/solver/compression/default.yaml) for all audio synthesis metrics + +Similarly, we provide different examples in our grids: +* [Evaluation with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py) +* [Evaluation with FAD and others](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py) diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md new file mode 100644 index 0000000000000000000000000000000000000000..dc3c1c2efd946552f1ca88eab8f8d00dfbb445c0 --- /dev/null +++ b/docs/MUSICGEN.md @@ -0,0 +1,419 @@ +# MusicGen: Simple and Controllable Music Generation + +AudioCraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. +MusicGen is a single stage auto-regressive Transformer model trained over a 32kHz +EnCodec tokenizer with 4 codebooks sampled at 50 Hz. +Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require +a self-supervised semantic representation, and it generates all 4 codebooks in one pass. By introducing +a small delay between the codebooks, we show we can predict them in parallel, thus having only 50 auto-regressive +steps per second of audio. +Check out our [sample page][musicgen_samples] or test the available demo! + + + Open In Colab + + + Open in HugginFace + +
+ +We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset +of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. + + +## Model Card + +See [the model card](../model_cards/MUSICGEN_MODEL_CARD.md). + + +## Installation + +Please follow the AudioCraft installation instructions from the [README](../README.md). + +AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters). + +## Usage + +We offer a number of way to interact with MusicGen: +1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co./spaces/facebook/MusicGen) +(huge thanks to all the HF team for their support). +2. You can run the extended demo on a Colab: +[colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing) +3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py). +4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). +5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) +which is regularly updated with contributions from @camenduru and the community. + + +## API + +We provide a simple API and 10 pre-trained models. The pre trained models are: +- `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co./facebook/musicgen-small) +- `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co./facebook/musicgen-medium) +- `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co./facebook/musicgen-melody) +- `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co./facebook/musicgen-large) +- `facebook/musicgen-melody-large`: 3.3B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co./facebook/musicgen-melody-large) +- `facebook/musicgen-stereo-*`: All the previous models fine tuned for stereo generation - + [small](https://huggingface.co./facebook/musicgen-stereo-small), + [medium](https://huggingface.co./facebook/musicgen-stereo-medium), + [large](https://huggingface.co./facebook/musicgen-stereo-large), + [melody](https://huggingface.co./facebook/musicgen-stereo-melody), + [melody large](https://huggingface.co./facebook/musicgen-stereo-melody-large). + +We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model. +In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller +GPUs will be able to generate short sequences, or longer sequences with the `facebook/musicgen-small` model. + +See after a quick example for using the API. + +```python +import torchaudio +from audiocraft.models import MusicGen +from audiocraft.data.audio import audio_write + +model = MusicGen.get_pretrained('facebook/musicgen-melody') +model.set_generation_params(duration=8) # generate 8 seconds. +wav = model.generate_unconditional(4) # generates 4 unconditional audio samples +descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] +wav = model.generate(descriptions) # generates 3 samples. + +melody, sr = torchaudio.load('./assets/bach.mp3') +# generates using the melody from the given audio and the provided descriptions. +wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr) + +for idx, one_wav in enumerate(wav): + # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. + audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) +``` + +## 🤗 Transformers Usage + +MusicGen is available in the 🤗 Transformers library from version 4.31.0 onwards, requiring minimal dependencies +and additional packages. Steps to get started: + +1. First install the 🤗 [Transformers library](https://github.com/huggingface/transformers) from main: + +```shell +pip install git+https://github.com/huggingface/transformers.git +``` + +2. Run the following Python code to generate text-conditional audio samples: + +```py +from transformers import AutoProcessor, MusicgenForConditionalGeneration + + +processor = AutoProcessor.from_pretrained("facebook/musicgen-small") +model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + +inputs = processor( + text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + padding=True, + return_tensors="pt", +) + +audio_values = model.generate(**inputs, max_new_tokens=256) +``` + +3. Listen to the audio samples either in an ipynb notebook: + +```py +from IPython.display import Audio + +sampling_rate = model.config.audio_encoder.sampling_rate +Audio(audio_values[0].numpy(), rate=sampling_rate) +``` + +Or save them as a `.wav` file using a third-party library, e.g. `scipy`: + +```py +import scipy + +sampling_rate = model.config.audio_encoder.sampling_rate +scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy()) +``` + +For more details on using the MusicGen model for inference using the 🤗 Transformers library, refer to the +[MusicGen docs](https://huggingface.co./docs/transformers/main/en/model_doc/musicgen) or the hands-on +[Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb). + + +## Training + +The [MusicGenSolver](../audiocraft/solvers/musicgen.py) implements MusicGen's training pipeline. +It defines an autoregressive language modeling task over multiple streams of discrete tokens +extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) +for more details on how to train such model). + +Note that **we do NOT provide any of the datasets** used for training MusicGen. +We provide a dummy dataset containing just a few examples for illustrative purposes. + +Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section. + + +**Warning:** As of version 1.1.0, a few breaking changes were introduced. Check the [CHANGELOG.md](../CHANGELOG.md) +file for more information. You might need to retrain some of your models. + +### Example configurations and grids + +We provide configurations to reproduce the released models and our research. +MusicGen solvers configuration are available in [config/solver/musicgen](../config/solver/musicgen), +in particular: +* MusicGen base model for text-to-music: +[`solver=musicgen/musicgen_base_32khz`](../config/solver/musicgen/musicgen_base_32khz.yaml) +* MusicGen model with chromagram-conditioning support: +[`solver=musicgen/musicgen_melody_32khz`](../config/solver/musicgen/musicgen_melody_32khz.yaml) + +We provide 3 different scales, e.g. `model/lm/model_scale=small` (300M), or `medium` (1.5B), and `large` (3.3B). + +Please find some example grids to train MusicGen at +[audiocraft/grids/musicgen](../audiocraft/grids/musicgen/). + +```shell +# text-to-music +dora grid musicgen.musicgen_base_32khz --dry_run --init +# melody-guided music generation +dora grid musicgen.musicgen_melody_base_32khz --dry_run --init +# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is setup. +``` + +### Music dataset and metadata + +MusicGen's underlying dataset is an AudioDataset augmented with music-specific metadata. +The MusicGen dataset implementation expects the metadata to be available as `.json` files +at the same location as the audio files. Learn more in the [datasets section](./DATASETS.md). + + +### Audio tokenizers + +We support a number of audio tokenizers: either pretrained EnCodec models, [DAC](https://github.com/descriptinc/descript-audio-codec), or your own models. +The tokenizer is controlled with the setting `compression_model_checkpoint`. +For instance, + +```bash +# Using the 32kHz EnCodec trained on music +dora run solver=musicgen/debug \ + compression_model_checkpoint=//pretrained/facebook/encodec_32khz \ + transformer_lm.n_q=4 transformer_lm.card=2048 + +# Using DAC +dora run solver=musicgen/debug \ + compression_model_checkpoint=//pretrained/dac_44khz \ + transformer_lm.n_q=9 transformer_lm.card=1024 \ + 'codebooks_pattern.delay.delays=[0,1,2,3,4,5,6,7,8]' + +# Using your own model after export (see ENCODEC.md) +dora run solver=musicgen/debug \ + compression_model_checkpoint=//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin \ + transformer_lm.n_q=... transformer_lm.card=... + +# Using your own model from its training checkpoint. +dora run solver=musicgen/debug \ + compression_model_checkpoint=//sig/SIG \ # where SIG is the Dora signature of the EnCodec XP. + transformer_lm.n_q=... transformer_lm.card=... +``` + +**Warning:** you are responsible for setting the proper value for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the codebook_pattern to match `n_q` as shown in the example for using DAC. . + + +### Training stereo models + +Use the option `interleave_stereo_codebooks.use` set to `True` to activate stereo training along with `channels=2`. Left and right channels will be +encoded separately by the compression model, then their codebook will be interleaved, e.g. order of codebook is +`[1_L, 1_R, 2_L, 2_R, ...]`. You will also need to update the delays for the codebook patterns to match the number of codebooks, and the `n_q` value passed to the transformer LM: +``` +dora run solver=musicgen/debug \ + compression_model_checkpoint=//pretrained/facebook/encodec_32khz \ + channels=2 interleave_stereo_codebooks.use=True \ + transformer_lm.n_q=8 transformer_lm.card=2048 \ + codebooks_pattern.delay.delays='[0, 0, 1, 1, 2, 2, 3, 3]' +``` + +### Fine tuning existing models + +You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular + +```bash +# Using pretrained MusicGen model. +dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//pretrained/facebook/musicgen-medium conditioner=text2music + +# Using another model you already trained with a Dora signature SIG. +dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//sig/SIG conditioner=text2music + +# Or providing manually a path +dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=/checkpoints/my_other_xp/checkpoint.th +``` + +**Warning:** You are responsible for selecting the other parameters accordingly, in a way that make it compatible + with the model you are fine tuning. Configuration is NOT automatically inherited from the model you continue from. In particular make sure to select the proper `conditioner` and `model/lm/model_scale`. + +**Warning:** We currently do not support fine tuning a model with slightly different layers. If you decide + to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`. + If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict + `{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix. + + +#### Fine tuning mono model to stereo + +You will not be able to `continue_from` a mono model with stereo training, as the shape of the embeddings and output linears +would not match. You can use the following snippet to prepare a proper finetuning checkpoint. + +```python +from pathlib import Path +import torch + +# Download the pretrained model, e.g. from +# https://huggingface.co./facebook/musicgen-melody/blob/main/state_dict.bin + +model_name = 'musicgen-melody' +root = Path.home() / 'checkpoints' +# You are responsible for downloading the following checkpoint in the proper location +input_state_dict_path = root / model_name / 'state_dict.bin' +state = torch.load(input_state_dict_path, 'cpu') +bs = state['best_state'] +# there is a slight different in format between training checkpoints and exported public checkpoints. +# If you want to use your own mono models from one of your training checkpont, following the instructions +# for exporting a model explained later on this page. +assert 'model' not in bs, 'The following code is for using an exported pretrained model' +nbs = dict(bs) +for k in range(8): + # We will just copy mono embeddings and linears twice, once for left and right channels. + nbs[f'linears.{k}.weight'] = bs[f'linears.{k//2}.weight'] + nbs[f'emb.{k}.weight'] = bs[f'emb.{k//2}.weight'] +torch.save({'best_state': {'model': nbs}}, root / f'stereo_finetune_{model_name}.th') +``` + +Now, you can use `$HOME/checkpoints/stereo_finetune_musicgen-melody.th` as a `continue_from` target (without a `//pretrained` prefix!). + +### Caching of EnCodec tokens + +It is possible to precompute the EnCodec tokens and other metadata. +An example of generating and using this cache provided in the [musicgen.musicgen_base_cached_32khz grid](../audiocraft/grids/musicgen/musicgen_base_cached_32khz.py). + +### Evaluation stage + +By default, evaluation stage is also computing the cross-entropy and the perplexity over the +evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run +or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md) +for more details on the requirements for each metric. + +We provide an off-the-shelf configuration to enable running the objective metrics +for audio generation in +[config/solver/musicgen/evaluation/objective_eval](../config/solver/musicgen/evaluation/objective_eval.yaml). + +One can then activate evaluation the following way: +```shell +# using the configuration +dora run solver=musicgen/debug solver/musicgen/evaluation=objective_eval +# specifying each of the fields, e.g. to activate KL computation +dora run solver=musicgen/debug evaluate.metrics.kld=true +``` + +See [an example evaluation grid](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py). + +### Generation stage + +The generation stage allows to generate samples conditionally and/or unconditionally and to perform +audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling +from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples +generated and the batch size used are controlled by the `dataset.generate` configuration +while the other generation parameters are defined in `generate.lm`. + +```shell +# control sampling parameters +dora run solver=musicgen/debug generate.lm.gen_duration=10 generate.lm.use_sampling=true generate.lm.top_k=15 +``` + +#### Listening to samples + +Note that generation happens automatically every 25 epochs. You can easily access and +compare samples between models (as long as they are trained) on the same dataset using the +MOS tool. For that first `pip install Flask gunicorn`. Then +``` +gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - +``` +And access the tool at [https://127.0.0.1:8895](https://127.0.0.1:8895). + +### Playing with the model + +Once you have launched some experiments, you can easily get access +to the Solver with the latest trained model using the following snippet. + +```python +from audiocraft.solvers.musicgen import MusicGen + +solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8) +solver.model +solver.dataloaders +``` + +### Importing / Exporting models + +We do not support currently loading a model from the Hugging Face implementation or exporting to it. +If you want to export your model in a way that is compatible with `audiocraft.models.MusicGen` +API, you can run: + +```python +from audiocraft.utils import export +from audiocraft import train +xp = train.main.get_xp_from_sig('SIG_OF_LM') +export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin') +# You also need to bundle the EnCodec model you used !! +## Case 1) you trained your own +xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC') +export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin') +## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix. +## This will actually not dump the actual model, simply a pointer to the right model to download. +export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin') +``` + +Now you can load your custom model with: +```python +import audiocraft.models +musicgen = audiocraft.models.MusicGen.get_pretrained('/checkpoints/my_audio_lm/') +``` + + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + +## FAQ + +#### I need help on Windows + +@FurkanGozukara made a complete tutorial for [AudioCraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4) + +#### I need help for running the demo on Colab + +Check [@camenduru tutorial on YouTube](https://www.youtube.com/watch?v=EGfxuTy9Eeo). + +#### What are top-k, top-p, temperature and classifier-free guidance? + +Check out [@FurkanGozukara tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt). + +#### Should I use FSDP or autocast ? + +The two are mutually exclusive (because FSDP does autocast on its own). +You can use autocast up to 1.5B (medium), if you have enough RAM on your GPU. +FSDP makes everything more complex but will free up some memory for the actual +activations by sharding the optimizer state. + +## Citation +``` +@article{copet2023simple, + title={Simple and Controllable Music Generation}, + author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, + year={2023}, + journal={arXiv preprint arXiv:2306.05284}, +} +``` + + +## License + +See license information in the [model card](../model_cards/MUSICGEN_MODEL_CARD.md). + + +[arxiv]: https://arxiv.org/abs/2306.05284 +[musicgen_samples]: https://ai.honu.io/papers/musicgen/ diff --git a/docs/TRAINING.md b/docs/TRAINING.md new file mode 100644 index 0000000000000000000000000000000000000000..148de295f2ddfed2e4e893576bf31e1485038b8e --- /dev/null +++ b/docs/TRAINING.md @@ -0,0 +1,312 @@ +# AudioCraft training pipelines + +AudioCraft training pipelines are built on top of PyTorch as our core deep learning library +and [Flashy](https://github.com/facebookresearch/flashy) as our training pipeline design library, +and [Dora](https://github.com/facebookresearch/dora) as our experiment manager. +AudioCraft training pipelines are designed to be research and experiment-friendly. + + +## Environment setup + +For the base installation, follow the instructions from the [README.md](../README.md). +Below are some additional instructions for setting up environment to train new models. + +### Team and cluster configuration + +In order to support multiple teams and clusters, AudioCraft uses an environment configuration. +The team configuration allows to specify cluster-specific configurations (e.g. SLURM configuration), +or convenient mapping of paths between the supported environments. + +Each team can have a yaml file under the [configuration folder](../config). To select a team set the +`AUDIOCRAFT_TEAM` environment variable to a valid team name (e.g. `labs` or `default`): +```shell +conda env config vars set AUDIOCRAFT_TEAM=default +``` + +Alternatively, you can add it to your `.bashrc`: +```shell +export AUDIOCRAFT_TEAM=default +``` + +If not defined, the environment will default to the `default` team. + +The cluster is automatically detected, but it is also possible to override it by setting +the `AUDIOCRAFT_CLUSTER` environment variable. + +Based on this team and cluster, the environment is then configured with: +* The dora experiment outputs directory. +* The available slurm partitions: categorized by global and team. +* A shared reference directory: In order to facilitate sharing research models while remaining +agnostic to the used compute cluster, we created the `//reference` symbol that can be used in +YAML config to point to a defined reference folder containing shared checkpoints +(e.g. baselines, models for evaluation...). + +**Important:** The default output dir for trained models and checkpoints is under `/tmp/`. This is suitable +only for quick testing. If you are doing anything serious you MUST edit the file `default.yaml` and +properly set the `dora_dir` entries. + +#### Overriding environment configurations + +You can set the following environmet variables to bypass the team's environment configuration: +* `AUDIOCRAFT_CONFIG`: absolute path to a team config yaml file. +* `AUDIOCRAFT_DORA_DIR`: absolute path to a custom dora directory. +* `AUDIOCRAFT_REFERENCE_DIR`: absolute path to the shared reference directory. + +## Training pipelines + +Each task supported in AudioCraft has its own training pipeline and dedicated solver. +Learn more about solvers and key designs around AudioCraft training pipeline below. +Please refer to the documentation of each task and model for specific information on a given task. + + +### Solvers + +The core training component in AudioCraft is the solver. A solver holds the definition +of how to solve a given task: It implements the training pipeline logic, combining the datasets, +model, optimization criterion and components and the full training loop. We refer the reader +to [Flashy](https://github.com/facebookresearch/flashy) for core principles around solvers. + +AudioCraft proposes an initial solver, the `StandardSolver` that is used as the base implementation +for downstream solvers. This standard solver provides a nice base management of logging, +checkpoints loading/saving, xp restoration, etc. on top of the base Flashy implementation. +In AudioCraft, we made the assumption that all tasks are following the same set of stages: +train, valid, evaluate and generation, each relying on a dedicated dataset. + +Each solver is responsible for defining the task to solve and the associated stages +of the training loop in order to leave the full ownership of the training pipeline +to the researchers. This includes loading the datasets, building the model and +optimisation components, registering them and defining the execution of each stage. +To create a new solver for a given task, one should extend the StandardSolver +and define each stage of the training loop. One can further customise its own solver +starting from scratch instead of inheriting from the standard solver. + +```python +from . import base +from .. import optim + + +class MyNewSolver(base.StandardSolver): + + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__(cfg) + # one can add custom attributes to the solver + self.criterion = torch.nn.L1Loss() + + def best_metric(self): + # here optionally specify which metric to use to keep track of best state + return 'loss' + + def build_model(self): + # here you can instantiate your models and optimization related objects + # this method will be called by the StandardSolver init method + self.model = ... + # the self.cfg attribute contains the raw configuration + self.optimizer = optim.build_optimizer(self.model.parameters(), self.cfg.optim) + # don't forget to register the states you'd like to include in your checkpoints! + self.register_stateful('model', 'optimizer') + # keep the model best state based on the best value achieved at validation for the given best_metric + self.register_best('model') + # if you want to add EMA around the model + self.register_ema('model') + + def build_dataloaders(self): + # here you can instantiate your dataloaders + # this method will be called by the StandardSolver init method + self.dataloaders = ... + + ... + + # For both train and valid stages, the StandardSolver relies on + # a share common_train_valid implementation that is in charge of + # accessing the appropriate loader, iterate over the data up to + # the specified number of updates_per_epoch, run the ``run_step`` + # function that you need to implement to specify the behavior + # and finally update the EMA and collect the metrics properly. + @abstractmethod + def run_step(self, idx: int, batch: tp.Any, metrics: dict): + """Perform one training or valid step on a given batch. + """ + ... # provide your implementation of the solver over a batch + + def train(self): + """Train stage. + """ + return self.common_train_valid('train') + + def valid(self): + """Valid stage. + """ + return self.common_train_valid('valid') + + @abstractmethod + def evaluate(self): + """Evaluate stage. + """ + ... # provide your implementation here! + + @abstractmethod + def generate(self): + """Generate stage. + """ + ... # provide your implementation here! +``` + +### About Epochs + +AudioCraft Solvers uses the concept of Epoch. One epoch doesn't necessarily mean one pass over the entire +dataset, but instead represent the smallest amount of computation that we want to work with before checkpointing. +Typically, we find that having an Epoch time around 30min is ideal both in terms of safety (checkpointing often enough) +and getting updates often enough. One Epoch is at least a `train` stage that lasts for `optim.updates_per_epoch` (2000 by default), +and a `valid` stage. You can control how long the valid stage takes with `dataset.valid.num_samples`. +Other stages (`evaluate`, `generate`) will only happen every X epochs, as given by `evaluate.every` and `generate.every`). + + +### Models + +In AudioCraft, a model is a container object that wraps one or more torch modules together +with potential processing logic to use in a solver. For example, a model would wrap an encoder module, +a quantisation bottleneck module, a decoder and some tensor processing logic. Each of the previous components +can be considered as a small « model unit » on its own but the container model is a practical component +to manipulate and train a set of modules together. + +### Datasets + +See the [dedicated documentation on datasets](./DATASETS.md). + +### Metrics + +See the [dedicated documentation on metrics](./METRICS.md). + +### Conditioners + +AudioCraft language models can be conditioned in various ways and the codebase offers a modular implementation +of different conditioners that can be potentially combined together. +Learn more in the [dedicated documentation on conditioning](./CONDITIONING.md). + +### Configuration + +AudioCraft's configuration is defined in yaml files and the framework relies on +[hydra](https://hydra.cc/docs/intro/) and [omegaconf](https://omegaconf.readthedocs.io/) to parse +and manipulate the configuration through Dora. + +##### :warning: Important considerations around configurations + +Our configuration management relies on Hydra and the concept of group configs to structure +and compose configurations. Updating the root default configuration files will then have +an impact on all solvers and tasks. +**One should never change the default configuration files. Instead they should use Hydra config groups in order to store custom configuration.** +Once this configuration is created and used for running experiments, you should not edit it anymore. + +Note that as we are using Dora as our experiment manager, all our experiment tracking is based on +signatures computed from delta between configurations. +**One must therefore ensure backward compatibilty of the configuration at all time.** +See [Dora's README](https://github.com/facebookresearch/dora) and the +[section below introduction Dora](#running-experiments-with-dora). + +##### Configuration structure + +The configuration is organized in config groups: +* `conditioner`: default values for conditioning modules. +* `dset`: contains all data source related information (paths to manifest files +and metadata for a given dataset). +* `model`: contains configuration for each model defined in AudioCraft and configurations +for different variants of models. +* `solver`: contains the default configuration for each solver as well as configuration +for each solver task, combining all the above components. +* `teams`: contains the cluster configuration per teams. See environment setup for more details. + +The `config.yaml` file is the main configuration that composes the above groups +and contains default configuration for AudioCraft. + +##### Solver's core configuration structure + +The core configuration structure shared across solver is available in `solvers/default.yaml`. + +##### Other configuration modules + +AudioCraft configuration contains the different setups we used for our research and publications. + +## Running experiments with Dora + +### Launching jobs + +Try launching jobs for different tasks locally with dora run: + +```shell +# run compression task with lightweight encodec +dora run solver=compression/debug +``` + +Most of the time, the jobs are launched through dora grids, for example: + +```shell +# run compression task through debug grid +dora grid compression.debug +``` + +Learn more about running experiments with Dora below. + +### A small introduction to Dora + +[Dora](https://github.com/facebookresearch/dora) is the experiment manager tool used in AudioCraft. +Check out the README to learn how Dora works. Here is a quick summary of what to know: +* An XP is a unique set of hyper-parameters with a given signature. The signature is a hash +of those hyper-parameters. We always refer to an XP with its signature, e.g. 9357e12e. We will see +after that one can retrieve the hyper-params and re-rerun it in a single command. +* In fact, the hash is defined as a delta between the base config and the one obtained +with the config overrides you passed from the command line. This means you must never change +the `conf/**.yaml` files directly., except for editing things like paths. Changing the default values +in the config files means the XP signature won't reflect that change, and wrong checkpoints might be reused. +I know, this is annoying, but the reason is that otherwise, any change to the config file would mean +that all XPs ran so far would see their signature change. + +#### Dora commands + +```shell +dora info -f 81de367c # this will show the hyper-parameter used by a specific XP. + # Be careful some overrides might present twice, and the right most one + # will give you the right value for it. + +dora run -d -f 81de367c # run an XP with the hyper-parameters from XP 81de367c. + # `-d` is for distributed, it will use all available GPUs. + +dora run -d -f 81de367c dataset.batch_size=32 # start from the config of XP 81de367c but change some hyper-params. + # This will give you a new XP with a new signature (e.g. 3fe9c332). + +dora info -f SIG -t # will tail the log (if the XP has scheduled). +# if you need to access the logs of the process for rank > 0, in particular because a crash didn't happen in the main +# process, then use `dora info -f SIG` to get the main log name (finished into something like `/5037674_0_0_log.out`) +# and worker K can accessed as `/5037674_0_{K}_log.out`. +# This is only for scheduled jobs, for local distributed runs with `-d`, then you should go into the XP folder, +# and look for `worker_{K}.log` logs. +``` + +An XP runs from a specific folder based on its signature, under the +`//experiments/audiocraft/outputs/` folder. +You can safely interrupt a training and resume it, it will reuse any existing checkpoint, +as it will reuse the same folder. If you made some change to the code and need to ignore +a previous checkpoint you can use `dora run --clear [RUN ARGS]`. + +If you have a Slurm cluster, you can also use the dora grid command, e.g. + +```shell +# run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py` +dora grid my_grid_folder.my_grid_name +# Run the following will simply display the grid and also initialized the Dora experiments database. +# You can then simply refer to a config using its signature (e.g. as `dora run -f SIG`). +dora grid my_grid_folder.my_grid_name --dry_run --init +``` + +Please refer to the [Dora documentation](https://github.com/facebookresearch/dora) for more information. + + +#### Clearing up past experiments + +```shell +# This will cancel all the XPs and delete their folder and checkpoints. +# It will then reschedule them starting from scratch. +dora grid my_grid_folder.my_grid_name --clear +# The following will delete the folder and checkpoint for a single XP, +# and then run it afresh. +dora run [-f BASE_SIG] [ARGS] --clear +``` diff --git a/egs/example/data.jsonl b/egs/example/data.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..63c3c333daa3418f52f952f9d018ccedee017899 --- /dev/null +++ b/egs/example/data.jsonl @@ -0,0 +1,2 @@ +{"path": "dataset/example/electro_1.mp3", "duration": 15.024, "sample_rate": 48000, "amplitude": null, "weight": null, "info_path": null} +{"path": "dataset/example/electro_2.mp3", "duration": 20.035918367346937, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null} diff --git a/model_cards/AUDIOGEN_MODEL_CARD.md b/model_cards/AUDIOGEN_MODEL_CARD.md new file mode 100644 index 0000000000000000000000000000000000000000..5dcd23d8276d8f474043976672ea249d8b2a9dd1 --- /dev/null +++ b/model_cards/AUDIOGEN_MODEL_CARD.md @@ -0,0 +1,79 @@ +# AudioGen Model Card + +## Model details +**Organization developing the model:** The FAIR team of Meta AI. + +**Model date:** This version of AudioGen was trained between July 2023 and August 2023. + +**Model version:** This is version 2 of the model, not to be confused with the original AudioGen model published in ["AudioGen: Textually Guided Audio Generation"][audiogen]. +In this version (v2), AudioGen was trained on the same data, but with some other differences: +1. This model was trained on 10 seconds (vs. 5 seconds in v1). +2. The discrete representation used under the hood is extracted using a retrained EnCodec model on the environmental sound data, following the EnCodec setup detailed in the ["Simple and Controllable Music Generation" paper][musicgen]. +3. No audio mixing augmentations. + +**Model type:** AudioGen consists of an EnCodec model for audio tokenization, and an auto-regressive language model based on the transformer architecture for audio modeling. The released model has 1.5B parameters. + +**Paper or resource for more information:** More information can be found in the paper [AudioGen: Textually Guided Audio Generation](https://arxiv.org/abs/2209.15352). + +**Citation details:** See [AudioGen paper][audiogen] + +**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0. + +**Where to send questions or comments about the model:** Questions and comments about AudioGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. + +## Intended use +**Primary intended use:** The primary use of AudioGen is research on AI-based audio generation, including: +- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science +- Generation of sound guided by text to understand current abilities of generative AI models by machine learning amateurs + +**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. + +**Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate audio pieces that create hostile or alienating environments for people. This includes generating audio that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. + +## Metrics + +**Models performance measures:** We used the following objective measure to evaluate the model on a standard audio benchmark: +- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish) +- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST) + +Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes: +- Overall quality of the audio samples; +- Text relevance to the provided text input; + +More details on performance measures and human studies can be found in the paper. + +**Decision thresholds:** Not applicable. + +## Evaluation datasets + +The model was evaluated on the [AudioCaps benchmark](https://audiocaps.github.io/). + +## Training datasets + +The model was trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects). + +## Evaluation results + +Below are the objective metrics obtained with the released model on AudioCaps (consisting of 10-second long samples). Note that the model differs from the original AudioGen model introduced in the paper, hence the difference in the metrics. + +| Model | Frechet Audio Distance | KLD | Text consistency | +|---|---|---|---| +| facebook/audiogen-medium | 1.77 | 1.58 | 0.30 | + +More information can be found in the paper [AudioGen: Textually Guided Audio Generation][audiogen], in the Experiments section. + +## Limitations and biases + +**Limitations:** +- The model is not able to generate realistic vocals. +- The model has been trained with English descriptions and will not perform as well in other languages. +- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results. + +**Biases:** The datasets used for training may be lacking of diversity and are not representative of all possible sound events. The generated samples from the model will reflect the biases from the training data. + +**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data. + +**Use cases:** Users must be aware of the biases, limitations and risks of the model. AudioGen is a model developed for artificial intelligence research on audio generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. + +[musicgen]: https://arxiv.org/abs/2306.05284 +[audiogen]: https://arxiv.org/abs/2209.15352 diff --git a/MODEL_CARD.md b/model_cards/MUSICGEN_MODEL_CARD.md similarity index 67% rename from MODEL_CARD.md rename to model_cards/MUSICGEN_MODEL_CARD.md index 6c2c9f883969eb905e74ad3376966d156cc5ca00..68e81d4467008d597f1e17105b37adff78c8218c 100644 --- a/MODEL_CARD.md +++ b/model_cards/MUSICGEN_MODEL_CARD.md @@ -12,11 +12,11 @@ **Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv]. -**Citation details** See [our paper][arxiv] +**Citation details:** See [our paper][arxiv] -**License** Code is released under MIT, model weights are released under CC-BY-NC 4.0. +**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0. -**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [Github repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. +**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. ## Intended use **Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including: @@ -26,7 +26,7 @@ **Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. -**Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. +**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. ## Metrics @@ -54,15 +54,24 @@ The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/data The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing. -## Quantitative analysis +## Evaluation results -More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Experimental Setup section. +Below are the objective metrics obtained on MusicCaps with the released model. Note that for the publicly released models, we had all the datasets go through a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs), in order to keep only the instrumental part. This explains the difference in objective metrics with the models used in the paper. + +| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity | +|---|---|---|---|---| +| facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - | +| facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - | +| facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - | +| facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 | + +More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section. ## Limitations and biases **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model. -**Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). +**Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). **Limitations:** @@ -78,4 +87,19 @@ More information can be found in the paper [Simple and Controllable Music Genera **Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. +## Update: stereo models and large melody. + +We further release a set of stereophonic capable models. Those were fine tuned for 200k updates starting +from the mono models. The training data is otherwise identical and capabilities and limitations are shared with the base modes. The stereo models work by getting 2 streams of tokens from the EnCodec model, and interleaving those using +the delay pattern. We also release a mono large model with melody conditioning capabilities. The list of new models +is as follow: + +- facebook/musicgen-stereo-small +- facebook/musicgen-stereo-medium +- facebook/musicgen-stereo-large +- facebook/musicgen-stereo-melody +- facebook/musicgen-melody-large +- facebook/musicgen-stereo-melody-large + + [arxiv]: https://arxiv.org/abs/2306.05284 diff --git a/mypy.ini b/mypy.ini index b2b9b2ec5055087b0a5d568d5607c7a4061b39e3..6ab60f2fd7545c803fca221614704a075b8f2188 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,4 @@ [mypy] -[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub] +[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*] ignore_missing_imports = True diff --git a/requirements.txt b/requirements.txt index 8cb86eb8c7eadddae38c2e545ef6aeb1f30cbd64..febb63e0a403b615eadcc29e5542ff70e970710f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,8 +13,11 @@ torch<2.1.0 torchaudio<2.1.0 huggingface_hub tqdm -transformers +transformers>=4.31.0 # need Encodec there. xformers demucs librosa gradio_client==0.2.6 +torchmetrics +encodec +protobuf diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/scripts/mos.py b/scripts/mos.py new file mode 100644 index 0000000000000000000000000000000000000000..a711c9ece23e72ed3a07032c7834ef7c56ab4f11 --- /dev/null +++ b/scripts/mos.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" +To run this script, from the root of the repo. Make sure to have Flask installed + + FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567 + # or if you have gunicorn + gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - + +""" +from collections import defaultdict +from functools import wraps +from hashlib import sha1 +import json +import math +from pathlib import Path +import random +import typing as tp + +from flask import Flask, redirect, render_template, request, session, url_for + +from audiocraft import train +from audiocraft.utils.samples.manager import get_samples_for_xps + + +SAMPLES_PER_PAGE = 8 +MAX_RATING = 5 +storage = Path(train.main.dora.dir / 'mos_storage') +storage.mkdir(exist_ok=True) +surveys = storage / 'surveys' +surveys.mkdir(exist_ok=True) +magma_root = Path(train.__file__).parent.parent +app = Flask('mos', static_folder=str(magma_root / 'scripts/static'), + template_folder=str(magma_root / 'scripts/templates')) +app.secret_key = b'audiocraft makes the best songs' + + +def normalize_path(path: Path): + """Just to make path a bit nicer, make them relative to the Dora root dir. + """ + path = path.resolve() + dora_dir = train.main.dora.dir.resolve() / 'xps' + return path.relative_to(dora_dir) + + +def get_full_path(normalized_path: Path): + """Revert `normalize_path`. + """ + return train.main.dora.dir.resolve() / 'xps' / normalized_path + + +def get_signature(xps: tp.List[str]): + """Return a signature for a list of XP signatures. + """ + return sha1(json.dumps(xps).encode()).hexdigest()[:10] + + +def ensure_logged(func): + """Ensure user is logged in. + """ + @wraps(func) + def _wrapped(*args, **kwargs): + user = session.get('user') + if user is None: + return redirect(url_for('login', redirect_to=request.url)) + return func(*args, **kwargs) + return _wrapped + + +@app.route('/login', methods=['GET', 'POST']) +def login(): + """Login user if not already, then redirect. + """ + user = session.get('user') + if user is None: + error = None + if request.method == 'POST': + user = request.form['user'] + if not user: + error = 'User cannot be empty' + if user is None or error: + return render_template('login.html', error=error) + assert user + session['user'] = user + redirect_to = request.args.get('redirect_to') + if redirect_to is None: + redirect_to = url_for('index') + return redirect(redirect_to) + + +@app.route('/', methods=['GET', 'POST']) +@ensure_logged +def index(): + """Offer to create a new study. + """ + errors = [] + if request.method == 'POST': + xps_or_grids = [part.strip() for part in request.form['xps'].split()] + xps = set() + for xp_or_grid in xps_or_grids: + xp_path = train.main.dora.dir / 'xps' / xp_or_grid + if xp_path.exists(): + xps.add(xp_or_grid) + continue + grid_path = train.main.dora.dir / 'grids' / xp_or_grid + if grid_path.exists(): + for child in grid_path.iterdir(): + if child.is_symlink(): + xps.add(child.name) + continue + errors.append(f'{xp_or_grid} is neither an XP nor a grid!') + assert xps or errors + blind = 'true' if request.form.get('blind') == 'on' else 'false' + xps = list(xps) + if not errors: + signature = get_signature(xps) + manifest = { + 'xps': xps, + } + survey_path = surveys / signature + survey_path.mkdir(exist_ok=True) + with open(survey_path / 'manifest.json', 'w') as f: + json.dump(manifest, f, indent=2) + return redirect(url_for('survey', blind=blind, signature=signature)) + return render_template('index.html', errors=errors) + + +@app.route('/survey/', methods=['GET', 'POST']) +@ensure_logged +def survey(signature): + success = request.args.get('success', False) + seed = int(request.args.get('seed', 4321)) + blind = request.args.get('blind', 'false') in ['true', 'on', 'True'] + exclude_prompted = request.args.get('exclude_prompted', 'false') in ['true', 'on', 'True'] + exclude_unprompted = request.args.get('exclude_unprompted', 'false') in ['true', 'on', 'True'] + max_epoch = int(request.args.get('max_epoch', '-1')) + survey_path = surveys / signature + assert survey_path.exists(), survey_path + + user = session['user'] + result_folder = survey_path / 'results' + result_folder.mkdir(exist_ok=True) + result_file = result_folder / f'{user}_{seed}.json' + + with open(survey_path / 'manifest.json') as f: + manifest = json.load(f) + + xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']] + names, ref_name = train.main.get_names(xps) + + samples_kwargs = { + 'exclude_prompted': exclude_prompted, + 'exclude_unprompted': exclude_unprompted, + 'max_epoch': max_epoch, + } + matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs) # fetch latest epoch + models_by_id = { + id: [{ + 'xp': xps[idx], + 'xp_name': names[idx], + 'model_id': f'{xps[idx].sig}-{sample.id}', + 'sample': sample, + 'is_prompted': sample.prompt is not None, + 'errors': [], + } for idx, sample in enumerate(samples)] + for id, samples in matched_samples.items() + } + experiments = [ + {'xp': xp, 'name': names[idx], 'epoch': list(matched_samples.values())[0][idx].epoch} + for idx, xp in enumerate(xps) + ] + + keys = list(matched_samples.keys()) + keys.sort() + rng = random.Random(seed) + rng.shuffle(keys) + model_ids = keys[:SAMPLES_PER_PAGE] + + if blind: + for key in model_ids: + rng.shuffle(models_by_id[key]) + + ok = True + if request.method == 'POST': + all_samples_results = [] + for id in model_ids: + models = models_by_id[id] + result = { + 'id': id, + 'is_prompted': models[0]['is_prompted'], + 'models': {} + } + all_samples_results.append(result) + for model in models: + rating = request.form[model['model_id']] + if rating: + rating = int(rating) + assert rating <= MAX_RATING and rating >= 1 + result['models'][model['xp'].sig] = rating + model['rating'] = rating + else: + ok = False + model['errors'].append('Please rate this model.') + if ok: + result = { + 'results': all_samples_results, + 'seed': seed, + 'user': user, + 'blind': blind, + 'exclude_prompted': exclude_prompted, + 'exclude_unprompted': exclude_unprompted, + } + print(result) + with open(result_file, 'w') as f: + json.dump(result, f) + seed = seed + 1 + return redirect(url_for( + 'survey', signature=signature, blind=blind, seed=seed, + exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, + max_epoch=max_epoch, success=True)) + + ratings = list(range(1, MAX_RATING + 1)) + return render_template( + 'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success, + exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch, + experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[], + ref_name=ref_name, already_filled=result_file.exists()) + + +@app.route('/audio/') +def audio(path: str): + full_path = Path('/') / path + assert full_path.suffix in [".mp3", ".wav"] + return full_path.read_bytes(), {'Content-Type': 'audio/mpeg'} + + +def mean(x): + return sum(x) / len(x) + + +def std(x): + m = mean(x) + return math.sqrt(sum((i - m)**2 for i in x) / len(x)) + + +@app.route('/results/') +@ensure_logged +def results(signature): + + survey_path = surveys / signature + assert survey_path.exists(), survey_path + result_folder = survey_path / 'results' + result_folder.mkdir(exist_ok=True) + + # ratings per model, then per user. + ratings_per_model = defaultdict(list) + users = [] + for result_file in result_folder.iterdir(): + if result_file.suffix != '.json': + continue + with open(result_file) as f: + results = json.load(f) + users.append(results['user']) + for result in results['results']: + for sig, rating in result['models'].items(): + ratings_per_model[sig].append(rating) + + fmt = '{:.2f}' + models = [] + for model in sorted(ratings_per_model.keys()): + ratings = ratings_per_model[model] + + models.append({ + 'sig': model, + 'samples': len(ratings), + 'mean_rating': fmt.format(mean(ratings)), + # the value 1.96 was probably chosen to achieve some + # confidence interval assuming gaussianity. + 'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5), + }) + return render_template('results.html', signature=signature, models=models, users=users) diff --git a/scripts/resample_dataset.py b/scripts/resample_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..af5288712b8d2cde2d9814c747275e69f6e970c8 --- /dev/null +++ b/scripts/resample_dataset.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Resampling script. +""" +import argparse +from pathlib import Path +import shutil +import typing as tp + +import submitit +import tqdm + +from audiocraft.data.audio import audio_read, audio_write +from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files +from audiocraft.data.audio_utils import convert_audio +from audiocraft.environment import AudioCraftEnvironment + + +def read_txt_files(path: tp.Union[str, Path]): + with open(args.files_path) as f: + lines = [line.rstrip() for line in f] + print(f"Read {len(lines)} in .txt") + lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']] + print(f"Filtered and keep {len(lines)} from .txt") + return lines + + +def read_egs_files(path: tp.Union[str, Path]): + path = Path(path) + if path.is_dir(): + if (path / 'data.jsonl').exists(): + path = path / 'data.jsonl' + elif (path / 'data.jsonl.gz').exists(): + path = path / 'data.jsonl.gz' + else: + raise ValueError("Don't know where to read metadata from in the dir. " + "Expecting either a data.jsonl or data.jsonl.gz file but none found.") + meta = load_audio_meta(path) + return [m.path for m in meta] + + +def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None): + if task_index is None: + env = submitit.JobEnvironment() + task_index = env.global_rank + shard_index = node_index * args.tasks_per_node + task_index + + if args.files_path is None: + lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)] + else: + files_path = Path(args.files_path) + if files_path.suffix == '.txt': + print(f"Reading file list from .txt file: {args.files_path}") + lines = read_txt_files(args.files_path) + else: + print(f"Reading file list from egs: {args.files_path}") + lines = read_egs_files(args.files_path) + + total_files = len(lines) + print( + f"Total of {total_files} processed with {n_shards} shards. " + + f"Current idx = {shard_index} -> {total_files // n_shards} files to process" + ) + for idx, line in tqdm.tqdm(enumerate(lines)): + + # skip if not part of this shard + if idx % n_shards != shard_index: + continue + + path = str(AudioCraftEnvironment.apply_dataset_mappers(line)) + root_path = str(args.root_path) + if not root_path.endswith('/'): + root_path += '/' + assert path.startswith(str(root_path)), \ + f"Mismatch between path and provided root: {path} VS {root_path}" + + try: + metadata_path = Path(path).with_suffix('.json') + out_path = args.out_path / path[len(root_path):] + out_metadata_path = out_path.with_suffix('.json') + out_done_token = out_path.with_suffix('.done') + + # don't reprocess existing files + if out_done_token.exists(): + continue + + print(idx, out_path, path) + mix, sr = audio_read(path) + mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0) + # enforce simple stereo + out_channels = mix_channels + if out_channels > 2: + print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels") + out_channels = 2 + out_sr = args.sample_rate if args.sample_rate is not None else sr + out_wav = convert_audio(mix, sr, out_sr, out_channels) + audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr, + format=args.format, normalize=False, strategy='clip') + if metadata_path.exists(): + shutil.copy(metadata_path, out_metadata_path) + else: + print(f"No metadata found at {str(metadata_path)}") + out_done_token.touch() + except Exception as e: + print(f"Error processing file line: {line}, {e}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Resample dataset with SLURM.") + parser.add_argument( + "--log_root", + type=Path, + default=Path.home() / 'tmp' / 'resample_logs', + ) + parser.add_argument( + "--files_path", + type=Path, + help="List of files to process, either .txt (one file per line) or a jsonl[.gz].", + ) + parser.add_argument( + "--root_path", + type=Path, + required=True, + help="When rewriting paths, this will be the prefix to remove.", + ) + parser.add_argument( + "--out_path", + type=Path, + required=True, + help="When rewriting paths, `root_path` will be replaced by this.", + ) + parser.add_argument("--xp_name", type=str, default="shutterstock") + parser.add_argument( + "--nodes", + type=int, + default=4, + ) + parser.add_argument( + "--tasks_per_node", + type=int, + default=20, + ) + parser.add_argument( + "--cpus_per_task", + type=int, + default=4, + ) + parser.add_argument( + "--memory_gb", + type=int, + help="Memory in GB." + ) + parser.add_argument( + "--format", + type=str, + default="wav", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=32000, + ) + parser.add_argument( + "--channels", + type=int, + ) + parser.add_argument( + "--partition", + default='learnfair', + ) + parser.add_argument("--qos") + parser.add_argument("--account") + parser.add_argument("--timeout", type=int, default=4320) + parser.add_argument('--debug', action='store_true', help='debug mode (local run)') + args = parser.parse_args() + n_shards = args.tasks_per_node * args.nodes + if args.files_path is None: + print("Warning: --files_path not provided, not recommended when processing more than 10k files.") + if args.debug: + print("Debugging mode") + process_dataset(args, n_shards=n_shards, node_index=0, task_index=0) + else: + + log_folder = Path(args.log_root) / args.xp_name / '%j' + print(f"Logging to: {log_folder}") + log_folder.parent.mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=str(log_folder)) + if args.qos: + executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos, slurm_account=args.account) + else: + executor.update_parameters(slurm_partition=args.partition) + executor.update_parameters( + slurm_job_name=args.xp_name, timeout_min=args.timeout, + cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1) + if args.memory_gb: + executor.update_parameters(mem=f'{args.memory_gb}GB') + jobs = [] + with executor.batch(): + for node_index in range(args.nodes): + job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index) + jobs.append(job) + for job in jobs: + print(f"Waiting on job {job.job_id}") + job.results() diff --git a/scripts/static/style.css b/scripts/static/style.css new file mode 100644 index 0000000000000000000000000000000000000000..a0df7c63a0d2dd9a79f33f5d869ca31c9da87e8d --- /dev/null +++ b/scripts/static/style.css @@ -0,0 +1,113 @@ +body { + background-color: #fbfbfb; + margin: 0; +} + +select, input { + font-size: 1em; + max-width: 100%; +} + +.xp_name { + font-family: monospace; +} + +.simple_form { + background-color: #dddddd; + padding: 1em; + margin: 0.5em; +} + +textarea { + margin-top: 0.5em; + margin-bottom: 0.5em; +} + +.rating { + background-color: grey; + padding-top: 5px; + padding-bottom: 5px; + padding-left: 8px; + padding-right: 8px; + margin-right: 2px; + cursor:pointer; +} + +.rating_selected { + background-color: purple; +} + +.content { + font-family: sans-serif; + background-color: #f6f6f6; + padding: 40px; + margin: 0 auto; + max-width: 1000px; +} + +.track label { + padding-top: 10px; + padding-bottom: 10px; +} +.track { + padding: 15px; + margin: 5px; + background-color: #c8c8c8; +} + +.submit-big { + width:400px; + height:30px; + font-size: 20px; +} + +.error { + color: red; +} + +.ratings { + margin-left: 10px; +} + +.important { + font-weight: bold; +} + +.survey { + margin-bottom: 100px; +} + +.success { + color: #25901b; + font-weight: bold; +} +.warning { + color: #8a1f19; + font-weight: bold; +} +.track>section { + display: flex; + align-items: center; +} + +.prompt { + display: flex; + align-items: center; +} + +.track>section>div { + padding-left: 10px; +} + +audio { + max-width: 280px; + max-height: 40px; + margin-left: 10px; + margin-right: 10px; +} + +.special { + font-weight: bold; + color: #2c2c2c; +} + diff --git a/scripts/templates/base.html b/scripts/templates/base.html new file mode 100644 index 0000000000000000000000000000000000000000..f74668c19ecb83090a8a2d82c026bf417190ec6d --- /dev/null +++ b/scripts/templates/base.html @@ -0,0 +1,16 @@ + + + + {% block head %} + + + AudioCraft — MOS + {% endblock %} + + +
+

AudioCraft — MOS

+ {% block content %}{% endblock %} +
+ + diff --git a/scripts/templates/index.html b/scripts/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..7bd3afe9d933271bb922c1a0a534dd6b86fe67bc --- /dev/null +++ b/scripts/templates/index.html @@ -0,0 +1,28 @@ +{% extends "base.html" %} +{% block content %} + +

+ Welcome {{session['user']}} to the internal MOS assistant for AudioCraft. + You can create custom surveys between your models, that you can + evaluate yourself, or with the help of your teammates, by simply + sharing a link! +

+ +{% for error in errors %} +

{{error}}

+{% endfor %} +
+
+
+ +
+
+ +
+ + + +{% endblock %} diff --git a/scripts/templates/login.html b/scripts/templates/login.html new file mode 100644 index 0000000000000000000000000000000000000000..dd89ac654bceca14a9dec7d1a7f8206d1425a7a1 --- /dev/null +++ b/scripts/templates/login.html @@ -0,0 +1,20 @@ +{% extends "base.html" %} +{% block content %} + +

+ You must identify yourself first! We use a highly secured protocol + where you just decide your username, and that's it. No password, no encryption, + just pure trust. +

+ +{% if error %} +

{{error}}

+{% endif %} + + + + + +{% endblock %} diff --git a/scripts/templates/results.html b/scripts/templates/results.html new file mode 100644 index 0000000000000000000000000000000000000000..8ddce59f0f617a836db75c8bc9768db7f9f17511 --- /dev/null +++ b/scripts/templates/results.html @@ -0,0 +1,17 @@ +{% extends "base.html" %} +{% block content %} + +

Results for survey #{{signature}}

+

Checkout the survey page for details on the models.

+

The following users voted: + {% for user in users %} + {{user}} + {% endfor %} + +{% for model in models %} +

{{model['sig']}} ({{model['samples']}} samples)

+

Ratings: {{model['mean_rating']}} ± {{model['std_rating']}}

+ +{% endfor %} + +{% endblock %} diff --git a/scripts/templates/survey.html b/scripts/templates/survey.html new file mode 100644 index 0000000000000000000000000000000000000000..785d1e61b7ac21619416ba70dd4719ff250f3f4b --- /dev/null +++ b/scripts/templates/survey.html @@ -0,0 +1,131 @@ +{% extends "base.html" %} +{% block content %} +

Survey #{{signature}}

+{% if success %} +

Your ratings have been saved! +You have been moved to the next random seed, if you want +to keep rating more samples.

+{% endif %} +{% if already_filled %} +

You already rated those samples in the past, + filling this form will override your previous ratings. +

+{% endif %} +

Welcome {{session['user']}} to the survey #{{signature}}. +Go to the result page to check the results. Go to the home page to start a new survey. +

+ +{% for error in errors %} +

{{error}}

+{% endfor %} + +{% if not blind %} +

Base config is: {{ref_name}}

+

The following experiments are compared:

+
    + {% for experiment in experiments %} +
  • {{experiment.xp.sig}} ({{experiment.epoch}} epochs): {{experiment.name}}
  • + {% endfor %} +
+{% else %} +

This is a blind experiment, the order of all XPs is shuffled with every sample.

+{% endif %} +

The current random seed is {{seed}}. You can change it with the following form, and also update blind/non blind. +

+ + + + + + + +
+ +

Samples

+
+
+{% for id in model_ids %} +
+

{{id}}

+ {% for model in models_by_id[id] %} + {% if loop.index == 1 and model.is_prompted %} +
+

Prompt is

+ +

Ground truth is

+ +
+ {% endif %} + {% for err in model['errors'] %} +

{{err}}

+ {% endfor %} +
+ {% if not blind %} +

{{model.xp.sig}}:

+ {% endif %} + +

Rating:

+
+ {% for rating in ratings %} + {{rating}} + {% endfor %} + +
+

+
+ {% endfor %} +
+
+{% endfor %} + + +
+ +{% endblock %} diff --git a/setup.cfg b/setup.cfg index dc7aa4bb991d928ae3f03e0d850c4bd699be866e..a00890009a88752714357210a73709a83b395849 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,3 +3,12 @@ max-line-length = 120 [flake8] max-line-length = 120 + +[coverage:report] +include = audiocraft/* +omit = + audiocraft/environment.py + audiocraft/solvers/* + audiocraft/utils/* + audiocraft/*/loaders.py + audiocraft/*/builders.py diff --git a/setup.py b/setup.py index 78a172b7c90003b689bde40b49cc8fe1fb8107d4..64e7d6fcb1092748f8151f6d3ed1767d3be1b34b 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,8 @@ -""" - Copyright (c) Meta Platforms, Inc. and affiliates. - All rights reserved. - - This source code is licensed under the license found in the - LICENSE file in the root directory of this source tree. - -""" +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. from pathlib import Path @@ -13,11 +10,11 @@ from setuptools import setup, find_packages NAME = 'audiocraft' -DESCRIPTION = 'Audio research library for PyTorch' +DESCRIPTION = 'Audio generation research library for PyTorch' -URL = 'https://github.com/fairinternal/audiocraft' +URL = 'https://github.com/facebookresearch/audiocraft' AUTHOR = 'FAIR Speech & Audio' -EMAIL = 'defossez@meta.com' +EMAIL = 'defossez@meta.com, jadecopet@meta.com' REQUIRES_PYTHON = '>=3.8.0' for line in open('audiocraft/__init__.py'): diff --git a/tests/adversarial/__init__.py b/tests/adversarial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/adversarial/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/adversarial/test_discriminators.py b/tests/adversarial/test_discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..fad89a0ae4534dc7967b6ccda194b9fd1dedbffe --- /dev/null +++ b/tests/adversarial/test_discriminators.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import torch + +from audiocraft.adversarial.discriminators import ( + MultiPeriodDiscriminator, + MultiScaleDiscriminator, + MultiScaleSTFTDiscriminator +) + + +class TestMultiPeriodDiscriminator: + + def test_mpd_discriminator(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + periods = [1, 2, 3] + mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C) + logits, fmaps = mpd(t0) + + assert len(logits) == len(periods) + assert len(fmaps) == len(periods) + assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits]) + assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) + + +class TestMultiScaleDiscriminator: + + def test_msd_discriminator(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + scale_norms = ['weight_norm', 'weight_norm'] + msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C) + logits, fmaps = msd(t0) + + assert len(logits) == len(scale_norms) + assert len(fmaps) == len(scale_norms) + assert all([logit.shape[0] == N and len(logit.shape) == 3 for logit in logits]) + assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) + + +class TestMultiScaleStftDiscriminator: + + def test_msstftd_discriminator(self): + N, C, T = 2, 2, random.randrange(1, 100_000) + t0 = torch.randn(N, C, T) + + n_filters = 4 + n_ffts = [128, 256, 64] + hop_lengths = [32, 64, 16] + win_lengths = [128, 256, 64] + + msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths, + win_lengths=win_lengths, in_channels=C) + logits, fmaps = msstftd(t0) + + assert len(logits) == len(n_ffts) + assert len(fmaps) == len(n_ffts) + assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits]) + assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) diff --git a/tests/adversarial/test_losses.py b/tests/adversarial/test_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..0e30bc3a6dde00003e13c00f15e977e39425063c --- /dev/null +++ b/tests/adversarial/test_losses.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import random + +import torch + +from audiocraft.adversarial import ( + AdversarialLoss, + get_adv_criterion, + get_real_criterion, + get_fake_criterion, + FeatureMatchingLoss, + MultiScaleDiscriminator, +) + + +class TestAdversarialLoss: + + def test_adversarial_single_multidiscriminator(self): + adv = MultiScaleDiscriminator() + optimizer = torch.optim.Adam( + adv.parameters(), + lr=1e-4, + ) + loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse') + adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake) + + B, C, T = 4, 1, random.randint(1000, 5000) + real = torch.randn(B, C, T) + fake = torch.randn(B, C, T) + + disc_loss = adv_loss.train_adv(fake, real) + assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float) + + loss, loss_feat = adv_loss(fake, real) + assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float) + # we did not specify feature loss + assert loss_feat.item() == 0. + + def test_adversarial_feat_loss(self): + adv = MultiScaleDiscriminator() + optimizer = torch.optim.Adam( + adv.parameters(), + lr=1e-4, + ) + loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse') + feat_loss = FeatureMatchingLoss() + adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake, feat_loss) + + B, C, T = 4, 1, random.randint(1000, 5000) + real = torch.randn(B, C, T) + fake = torch.randn(B, C, T) + + loss, loss_feat = adv_loss(fake, real) + + assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float) + assert isinstance(loss_feat, torch.Tensor) and isinstance(loss.item(), float) + + +class TestGeneratorAdversarialLoss: + + def test_hinge_generator_adv_loss(self): + adv_loss = get_adv_criterion(loss_type='hinge') + + t0 = torch.randn(1, 2, 0) + t1 = torch.FloatTensor([1.0, 2.0, 3.0]) + + assert adv_loss(t0).item() == 0.0 + assert adv_loss(t1).item() == -2.0 + + def test_mse_generator_adv_loss(self): + adv_loss = get_adv_criterion(loss_type='mse') + + t0 = torch.randn(1, 2, 0) + t1 = torch.FloatTensor([1.0, 1.0, 1.0]) + t2 = torch.FloatTensor([2.0, 5.0, 5.0]) + + assert adv_loss(t0).item() == 0.0 + assert adv_loss(t1).item() == 0.0 + assert adv_loss(t2).item() == 11.0 + + +class TestDiscriminatorAdversarialLoss: + + def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor): + disc_loss_real = get_real_criterion(loss_type) + disc_loss_fake = get_fake_criterion(loss_type) + + loss = disc_loss_fake(fake) + disc_loss_real(real) + return loss + + def test_hinge_discriminator_adv_loss(self): + loss_type = 'hinge' + t0 = torch.FloatTensor([0.0, 0.0, 0.0]) + t1 = torch.FloatTensor([1.0, 2.0, 3.0]) + + assert self._disc_loss(loss_type, t0, t0).item() == 2.0 + assert self._disc_loss(loss_type, t1, t1).item() == 3.0 + + def test_mse_discriminator_adv_loss(self): + loss_type = 'mse' + + t0 = torch.FloatTensor([0.0, 0.0, 0.0]) + t1 = torch.FloatTensor([1.0, 1.0, 1.0]) + + assert self._disc_loss(loss_type, t0, t0).item() == 1.0 + assert self._disc_loss(loss_type, t1, t0).item() == 2.0 + + +class TestFeatureMatchingLoss: + + def test_features_matching_loss_base(self): + ft_matching_loss = FeatureMatchingLoss() + length = random.randrange(1, 100_000) + t1 = torch.randn(1, 2, length) + + loss = ft_matching_loss([t1], [t1]) + assert isinstance(loss, torch.Tensor) + assert loss.item() == 0.0 + + def test_features_matching_loss_raises_exception(self): + ft_matching_loss = FeatureMatchingLoss() + length = random.randrange(1, 100_000) + t1 = torch.randn(1, 2, length) + t2 = torch.randn(1, 2, length + 1) + + with pytest.raises(AssertionError): + ft_matching_loss([], []) + + with pytest.raises(AssertionError): + ft_matching_loss([t1], [t1, t1]) + + with pytest.raises(AssertionError): + ft_matching_loss([t1], [t2]) + + def test_features_matching_loss_output(self): + loss_nonorm = FeatureMatchingLoss(normalize=False) + loss_layer_normed = FeatureMatchingLoss(normalize=True) + + length = random.randrange(1, 100_000) + t1 = torch.randn(1, 2, length) + t2 = torch.randn(1, 2, length) + + assert loss_nonorm([t1, t2], [t1, t2]).item() == 0.0 + assert loss_layer_normed([t1, t2], [t1, t2]).item() == 0.0 + + t3 = torch.FloatTensor([1.0, 2.0, 3.0]) + t4 = torch.FloatTensor([2.0, 10.0, 3.0]) + + assert loss_nonorm([t3], [t4]).item() == 3.0 + assert loss_nonorm([t3, t3], [t4, t4]).item() == 6.0 + + assert loss_layer_normed([t3], [t4]).item() == 3.0 + assert loss_layer_normed([t3, t3], [t4, t4]).item() == 3.0 diff --git a/tests/common_utils/temp_utils.py b/tests/common_utils/temp_utils.py index d1e0367e979c8b9fea65472c373916d956ad5aaa..b45d896836799edcf1fee271409b390b3b6e4127 100644 --- a/tests/common_utils/temp_utils.py +++ b/tests/common_utils/temp_utils.py @@ -33,7 +33,7 @@ class TempDirMixin: cls.temp_dir_ = None except PermissionError: # On Windows there is a know issue with `shutil.rmtree`, - # which fails intermittenly. + # which fails intermittently. # https://github.com/python/cpython/issues/74168 # Following the above thread, we ignore it. pass diff --git a/tests/common_utils/wav_utils.py b/tests/common_utils/wav_utils.py index d3a563ee1749a58217ece55c9a08b8d93c0fc386..cc14a9caa77af2b0d4cb01c8eedc9bdcb4713996 100644 --- a/tests/common_utils/wav_utils.py +++ b/tests/common_utils/wav_utils.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. from pathlib import Path -import typing as tp import torch -import torchaudio + +from audiocraft.data.audio import audio_write def get_white_noise(chs: int = 1, num_frames: int = 1): @@ -22,11 +22,8 @@ def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): def save_wav(path: str, wav: torch.Tensor, sample_rate: int): + assert wav.dim() == 2, wav.shape fp = Path(path) - kwargs: tp.Dict[str, tp.Any] = {} - if fp.suffix == '.wav': - kwargs['encoding'] = 'PCM_S' - kwargs['bits_per_sample'] = 16 - elif fp.suffix == '.mp3': - kwargs['compression'] = 320 - torchaudio.save(str(fp), wav, sample_rate, **kwargs) + assert fp.suffix in ['.mp3', '.ogg', '.wav', '.flac'], fp + audio_write(fp.parent / fp.stem, wav, sample_rate, fp.suffix[1:], + normalize=False, strategy='clip', peak_clip_headroom_db=0) diff --git a/tests/data/test_audio_dataset.py b/tests/data/test_audio_dataset.py index b69c9c397830738b73d6c229009f84b867cda801..b591ea6137f48d0d97fcd1243c5f5d258670a474 100644 --- a/tests/data/test_audio_dataset.py +++ b/tests/data/test_audio_dataset.py @@ -313,7 +313,7 @@ class TestAudioDataset(TempDirMixin): def _get_histogram(dataset, repetitions=20_000): counts = {file_meta.path: 0. for file_meta in meta} for _ in range(repetitions): - file_meta = dataset.sample_file(rng) + file_meta = dataset.sample_file(0, rng) counts[file_meta.path] += 1 return {name: count / repetitions for name, count in counts.items()} diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/tests/losses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..b6681e12c453dea5aeba738ab252d1923b7e0941 --- /dev/null +++ b/tests/losses/test_losses.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import torch + +from audiocraft.losses import ( + MelSpectrogramL1Loss, + MultiScaleMelSpectrogramLoss, + MRSTFTLoss, + SISNR, + STFTLoss, +) + + +def test_mel_l1_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050) + loss = mel_l1(t1, t2) + loss_same = mel_l1(t1, t1) + + assert isinstance(loss, torch.Tensor) + assert isinstance(loss_same, torch.Tensor) + assert loss_same.item() == 0.0 + + +def test_msspec_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050) + loss = msspec(t1, t2) + loss_same = msspec(t1, t1) + + assert isinstance(loss, torch.Tensor) + assert isinstance(loss_same, torch.Tensor) + assert loss_same.item() == 0.0 + + +def test_mrstft_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + mrstft = MRSTFTLoss() + loss = mrstft(t1, t2) + + assert isinstance(loss, torch.Tensor) + + +def test_sisnr_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + sisnr = SISNR() + loss = sisnr(t1, t2) + + assert isinstance(loss, torch.Tensor) + + +def test_stft_loss(): + N, C, T = 2, 2, random.randrange(1000, 100_000) + t1 = torch.randn(N, C, T) + t2 = torch.randn(N, C, T) + + mrstft = STFTLoss() + loss = mrstft(t1, t2) + + assert isinstance(loss, torch.Tensor) diff --git a/tests/models/test_audiogen.py b/tests/models/test_audiogen.py new file mode 100644 index 0000000000000000000000000000000000000000..3850af066cedd5ea38bd9aead9634d6aaf938218 --- /dev/null +++ b/tests/models/test_audiogen.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from audiocraft.models import AudioGen + + +class TestAudioGenModel: + def get_audiogen(self): + ag = AudioGen.get_pretrained(name='debug', device='cpu') + ag.set_generation_params(duration=2.0, extend_stride=2.) + return ag + + def test_base(self): + ag = self.get_audiogen() + assert ag.frame_rate == 25 + assert ag.sample_rate == 16000 + assert ag.audio_channels == 1 + + def test_generate_continuation(self): + ag = self.get_audiogen() + prompt = torch.randn(3, 1, 16000) + wav = ag.generate_continuation(prompt, 16000) + assert list(wav.shape) == [3, 1, 32000] + + prompt = torch.randn(2, 1, 16000) + wav = ag.generate_continuation( + prompt, 16000, ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 32000] + + prompt = torch.randn(2, 1, 16000) + with pytest.raises(AssertionError): + wav = ag.generate_continuation( + prompt, 16000, ['youpi', 'lapin dort', 'one too many']) + + def test_generate(self): + ag = self.get_audiogen() + wav = ag.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 32000] + + def test_generate_long(self): + ag = self.get_audiogen() + ag.max_duration = 3. + ag.set_generation_params(duration=4., extend_stride=2.) + wav = ag.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 16000 * 4] diff --git a/tests/models/test_multibanddiffusion.py b/tests/models/test_multibanddiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2702a3cb5fe402bf96911dbc992d2749cb18a4c0 --- /dev/null +++ b/tests/models/test_multibanddiffusion.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import numpy as np +import torch +from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess +from audiocraft.models import EncodecModel, DiffusionUnet +from audiocraft.modules import SEANetEncoder, SEANetDecoder +from audiocraft.modules.diffusion_schedule import NoiseSchedule +from audiocraft.quantization import DummyQuantizer + + +class TestMBD: + + def _create_mbd(self, + sample_rate: int, + channels: int, + n_filters: int = 3, + n_residual_layers: int = 1, + ratios: list = [5, 4, 3, 2], + num_steps: int = 1000, + codec_dim: int = 128, + **kwargs): + frame_rate = np.prod(ratios) + encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters, + n_residual_layers=n_residual_layers, ratios=ratios) + decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters, + n_residual_layers=n_residual_layers, ratios=ratios) + quantizer = DummyQuantizer() + compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, + sample_rate=sample_rate, channels=channels, **kwargs) + diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim) + schedule = NoiseSchedule(device='cpu', num_steps=num_steps) + DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule) + mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model) + return mbd + + def test_model(self): + random.seed(1234) + sample_rate = 24_000 + channels = 1 + codec_dim = 128 + mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim) + for _ in range(10): + length = random.randrange(1, 10_000) + x = torch.randn(2, channels, length) + res = mbd.regenerate(x, sample_rate) + assert res.shape == x.shape diff --git a/tests/models/test_musicgen.py b/tests/models/test_musicgen.py index d43cf73763f6c690ab0b277227ac225b286fa143..2b32ac5d52e6ba3ba8f2b413e54e1b5ac5839016 100644 --- a/tests/models/test_musicgen.py +++ b/tests/models/test_musicgen.py @@ -10,7 +10,7 @@ import torch from audiocraft.models import MusicGen -class TestSEANetModel: +class TestMusicGenModel: def get_musicgen(self): mg = MusicGen.get_pretrained(name='debug', device='cpu') mg.set_generation_params(duration=2.0, extend_stride=2.) @@ -56,3 +56,10 @@ class TestSEANetModel: wav = mg.generate( ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 32000 * 4] + + def test_generate_two_step_cfg(self): + mg = self.get_musicgen() + mg.set_generation_params(duration=2.0, extend_stride=2., two_step_cfg=True) + wav = mg.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 64000] diff --git a/tests/modules/test_activations.py b/tests/modules/test_activations.py new file mode 100644 index 0000000000000000000000000000000000000000..24e30d4cd87683430488bfa442e098b34229a5ee --- /dev/null +++ b/tests/modules/test_activations.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + +from audiocraft.modules.activations import CustomGLU + + +class TestActivations: + def test_custom_glu_calculation(self): + + activation = CustomGLU(nn.Identity()) + + initial_shape = (4, 8, 8) + + part_a = torch.ones(initial_shape) * 2 + part_b = torch.ones(initial_shape) * -1 + input = torch.cat((part_a, part_b), dim=-1) + + output = activation(input) + + # ensure all dimensions match initial shape + assert output.shape == initial_shape + # ensure the gating was calculated correctly a * f(b) + assert torch.all(output == -2).item() diff --git a/tests/modules/test_rope.py b/tests/modules/test_rope.py index 067c6f067acbf27fb0fef5c2b812c22474c4fcd0..ec8d16c08c4925871e20435709674e80cf150349 100644 --- a/tests/modules/test_rope.py +++ b/tests/modules/test_rope.py @@ -11,7 +11,7 @@ from audiocraft.modules.transformer import StreamingTransformer, set_efficient_a def test_rope(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') B, T, H, C = 8, 75, 16, 128 rope = RotaryEmbedding(dim=C) @@ -24,7 +24,7 @@ def test_rope(): def test_rope_io_dtypes(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') B, T, H, C = 8, 75, 16, 128 rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32) @@ -48,7 +48,7 @@ def test_rope_io_dtypes(): def test_transformer_with_rope(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') torch.manual_seed(1234) for pos in ['rope', 'sin_rope']: tr = StreamingTransformer( @@ -64,7 +64,7 @@ def test_transformer_with_rope(): @torch.no_grad() def test_rope_streaming(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') torch.manual_seed(1234) tr = StreamingTransformer( 16, 4, 2, causal=True, dropout=0., @@ -92,7 +92,7 @@ def test_rope_streaming(): @torch.no_grad() def test_rope_streaming_past_context(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') torch.manual_seed(1234) for context in [None, 10]: @@ -122,7 +122,7 @@ def test_rope_streaming_past_context(): def test_rope_memory_efficient(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') torch.manual_seed(1234) tr = StreamingTransformer( 16, 4, 2, custom=True, dropout=0., layer_scale=0.1, @@ -143,7 +143,7 @@ def test_rope_memory_efficient(): def test_rope_with_xpos(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') B, T, H, C = 8, 75, 16, 128 rope = RotaryEmbedding(dim=C, xpos=True) @@ -156,7 +156,7 @@ def test_rope_with_xpos(): def test_positional_scale(): - set_efficient_attention_backend('xformers') + set_efficient_attention_backend('torch') B, T, H, C = 8, 75, 16, 128 rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0) diff --git a/tests/modules/test_transformer.py b/tests/modules/test_transformer.py index ff7dfe4c2de05112aec55ddea9c8fd978668f80b..ee74ba06614bd8dafd204ecc84e8fb74527cee69 100644 --- a/tests/modules/test_transformer.py +++ b/tests/modules/test_transformer.py @@ -86,7 +86,7 @@ def test_streaming_api(): def test_memory_efficient(): - for backend in ['torch', 'xformers']: + for backend in ['torch']: torch.manual_seed(1234) set_efficient_attention_backend(backend) @@ -132,7 +132,7 @@ def test_attention_as_float32(): @torch.no_grad() def test_streaming_memory_efficient(): - for backend in ['torch', 'xformers']: + for backend in ['torch']: torch.manual_seed(1234) set_efficient_attention_backend(backend) tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True) @@ -173,7 +173,7 @@ def test_cross_attention(): cross_x = torch.randn(2, 3, 16) y_ref = m(x) y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x) - # With norm_first, the two should be exactly yhe same, + # With norm_first, the two should be exactly the same, # but with norm_first=False, we get 2 normalization in a row # and the epsilon value leads to a tiny change. atol = 0. if norm_first else 1e-6