commit 6760f29bd0b1a4fb496e246a269f55da27cd0f1c Author: jason-on-salt-a40 Date: Thu Mar 21 11:02:20 2024 -0700 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9135f6d --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +__pycache__/ +*.py[cod] +*$py.class +*.egg-info +.pytest_cache +.ipynb_checkpoints + +thumbs.db +.DS_Store +.idea +*.log +*.pdf +*.mkv +*.mp4 +*.png +*.wav +*.mp3 + +*durip* +*rtx* +*l40* +*a40* + +!/demo/ +!/demo/* \ No newline at end of file diff --git a/LICENSE-CODE b/LICENSE-CODE new file mode 100644 index 0000000..cbe5ad1 --- /dev/null +++ b/LICENSE-CODE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/LICENSE-MODEL b/LICENSE-MODEL new file mode 100644 index 0000000..d02930f --- /dev/null +++ b/LICENSE-MODEL @@ -0,0 +1,42 @@ +Coqui Public Model License 1.0.0 +https://coqui.ai/cpml.txt + +This license allows only non-commercial use of a machine learning model and its outputs. + +Acceptance +In order to get any license under these terms, you must agree to them as both strict obligations and conditions to all your licenses. + +Licenses +The licensor grants you a copyright license to do everything you might do with the model that would otherwise infringe the licensor's copyright in it, for any non-commercial purpose. The licensor grants you a patent license that covers patent claims the licensor can license, or becomes able to license, that you would infringe by using the model in the form provided by the licensor, for any non-commercial purpose. + +Non-commercial Purpose +Non-commercial purposes include any of the following uses of the model or its output, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output. + +Personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, amateur pursuits, or religious observance. +Use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development. Use of the model to train other models for commercial use is not a non-commercial purpose. +Use by any charitable organization for charitable purposes, or for testing or evaluation. Use for revenue-generating activity, including projects directly funded by government grants, is not a non-commercial purpose. +Notices +You must ensure that anyone who gets a copy of any part of the model, or any modification of the model, or their output, from you also gets a copy of these terms or the URL for them above. + +No Other Rights +These terms do not allow you to sublicense or transfer any of your licenses to anyone else, or prevent the licensor from granting licenses to anyone else. These terms do not imply any other licenses. + +Patent Defense +If you make any written claim that the model infringes or contributes to infringement of any patent, your licenses for the model granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company. + +Violations +The first time you are notified in writing that you have violated any of these terms, or done anything with the model or its output that is not covered by your licenses, your licenses can nonetheless continue if you come into full compliance with these terms, and take practical steps to correct past violations, within 30 days of receiving notice. Otherwise, all your licenses end immediately. + +No Liability +AS FAR AS THE LAW ALLOWS, THE MODEL AND ITS OUTPUT COME AS IS, WITHOUT ANY WARRANTY OR CONDITION, AND THE LICENSOR WILL NOT BE LIABLE TO YOU FOR ANY DAMAGES ARISING OUT OF THESE TERMS OR THE USE OR NATURE OF THE MODEL OR ITS OUTPUT, UNDER ANY KIND OF LEGAL CLAIM. IF THIS PROVISION IS NOT ENFORCEABLE IN YOUR JURISDICTION, YOUR LICENSES ARE VOID. + +Definitions +The licensor is the individual or entity offering these terms, and the model is the model the licensor makes available under these terms, including any documentation or similar information about the model. + +You refers to the individual or entity agreeing to these terms. + +Your company is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. Control means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect. + +Your licenses are all the licenses granted to you under these terms. + +Use means anything you do with the model or its output requiring one of your licenses. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..049bab3 --- /dev/null +++ b/README.md @@ -0,0 +1,69 @@ +# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild +[Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) + +TL;DR: +VoiceCraft is a token infilling neural codec language model, that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts. + +To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference. + + +## TODO +The TODOs left will be completed by the end of March 2024. +- [x] Codebase upload +- [x] Environment setup +- [x] Inference demo for speech editing and TTS +- [] Upload model weights +- [] Training guidance +- [] Upload the RealEdit dataset + +## Environment setup +```bash +conda create -n voicecraft python=3.9.16 +conda activate voicecraft + +pip install torch==2.0.1 torchaudio==2.0.2 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201 +apt-get install ffmpeg # if you don't already have ffmpeg installed +pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft +apt-get install espeak-ng # backend for the phonemizer installed below +pip install phonemizer==3.2.1 +pip install tensorboard +pip install datasets==2.12.0 +# install MFA for getting forced-alignment, this could take a few minutes +conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068 +# conda install pocl # above gives an warning for installing pocl, not sure if really need this + +# to run ipynb +conda install -n voicecraft ipykernel --update-deps --force-reinstall +``` + +## Inference Examples +Checkout [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb) + +## License +The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some of the code from other repository that are under different licenses: `./models/codebooks_patterns.py` is under MIT license; `./models/modules`, `./steps/optim.py`, `data/tokenizer.py` are under Apache License, Version 2.0; the phonemizer we used is under GNU 3.0 License. For drop-in replacement of the phonemizer (i.e. text to IPA phoneme mapping), try [g2p](https://github.com/roedoejet/g2p) (MIT License) or [OpenPhonemizer](https://github.com/NeuralVox/OpenPhonemizer) (BSD-3-Clause Clear), although these are not tested. + + + +## Acknowledgement +We thank Feiteng for his [VALL-E reproduction](https://github.com/lifeiteng/vall-e), and we thank audiocraft team for open-sourcing [encodec](https://github.com/facebookresearch/audiocraft). + +## Citation +``` +@article{peng2024voicecraft, + author = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David}, + title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild}, + journal = {arXiv}, + year = {2024}, +} +``` + +## Disclaimer +Any organization or individual is prohibited from using any technology mentioned in this paper to generate or edit someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws. + diff --git a/config.py b/config.py new file mode 100644 index 0000000..466c6ad --- /dev/null +++ b/config.py @@ -0,0 +1,86 @@ +import argparse + + +def MyParser(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + # general training + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--precision", type=str, default="float16") + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--resume", action="store_true", default=False) + parser.add_argument("--tb_write_every_n_steps", type=int, default=100) + parser.add_argument("--print_every_n_steps", type=int, default=400) + parser.add_argument("--val_every_n_steps", type=int, default=800) + parser.add_argument("--lr", type=float, default=0.05) + parser.add_argument("--batch_size", type=int, default=100, help="this is the effective batch size, no matter whether using gradient_accumulation_steps, not used if we specified max_num_tokens") + parser.add_argument("--max_num_tokens", type=int, default=100000, help="max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note this is the final effective batch size per GPU, i.e. gradient accumulated batch size per gpu") + parser.add_argument("--val_max_num_tokens", type=int, default=None, help="FOR validation") + parser.add_argument("--num_buckets", type=int, default=6, help='used for dynamic batching, bucketing the samples based on the number of tokens') + parser.add_argument("--dynamic_batching", type=int, default=0) + parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--warmup_fraction", type=float, default=0.01, help="use linear warmup, the proportion of the training steps that are used for warming up") + parser.add_argument("--num_epochs", type=int, default=10) + parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore n_epochs and use num_steps as the total number of amount of training, can try e.g. 400000 i.e. 400k steps") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_(), not used if we use ScaledAdam optimizer") + parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement") + parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps") + + # optimizer focused + parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler") + parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='after which significantly reduce the lr. a param for the eden optimizer') + parser.add_argument("--pseudo_epoch_size", type=int, default=3000, help="only use for Eden scheduler.") + parser.add_argument("--reduce_lr_start_epoch", type=int, default=4) + parser.add_argument("--clipping_update_period", type=int, default=600) + + + # path + parser.add_argument("--exp_dir", type=str, default=None, help="will be combined with dataset name") + parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'gigaspeech', they are folder name in the data dir also") + parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file") + parser.add_argument("--phn_folder_name", type=str, default="phonemes", help="for libritts I also have arpa phns, in which case should be phonemes_arpa") + parser.add_argument("--encodec_folder_name", type=str, default="encodec_16khz_4codebooks", help="folder where encodec codes are stored") + parser.add_argument("--manifest_name", type=str, default="manifest", help="metadata filename") + + # data focused + parser.add_argument("--pad_x", type=int, default=1, help="whether or not always pad x to have text_max_length. select 1 to get the maximal memory consumption, but the actual case should be smaller, better to have it being 0") + parser.add_argument("--audio_max_length", type=float, default=20, help="in second, crop or drop the audio is length is longer than this") + parser.add_argument("--audio_min_length", type=float, default=2, help="in second, drop the audio if length is shorter than this") + parser.add_argument("--text_max_length", type=int, default=400, help='if too long, we crop or drop') + parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop") + parser.add_argument("--encodec_sr", type=int, default=50, help="for my encodec that takes 16kHz audio with a downsample rate of 320, the codec sample rate is 50Hz, i.e. 50 codes (x n_codebooks) per second") + parser.add_argument("--drop_long", type=int, default=0, help="if this is true, will drop example whose encodec sequence or phone sequence is too long, rather than cropping, to reduce hellucination") + + # encodec and token rearrangement + parser.add_argument('--mask_len_min', type=int, default=1, help='Minimum mask length') + parser.add_argument('--mask_len_max', type=int, default=600, help='Maximum mask length') + parser.add_argument("--eos", type=int, default=-1, help="this is to be used with reduced_eog, where we end the utterance with eos, and end the generated segment with eog, also when this is used, the n_special should be 4") + parser.add_argument("--reduced_eog", type=int, default=0, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts") + parser.add_argument("--special_first", type=int, default=0, help="if 1, need to have special tokens to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard coded to have 3 special tokens") + parser.add_argument("--n_special", type=int, default=3, help="empty, eog, pad, (eos)") + parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']") + parser.add_argument("--max_mask_portion",type=float,default=0.7,help="should mask a utterance for more than this portion") + parser.add_argument("--max_n_spans", type=int, default=3, help='maximal number of spans, only use when using multicm3, this is used to decide number of mask_embedding, and max clamp value if use Poisson distribution, if use uniform distribution to sample number of spans if will be uniform(1,max_n_spans)') + parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether shuffle the mask embedding, so that mask:0 is not the most well trained, default is not shuffling. The default has it's benefit, as it make sure that mask:0 always appear the first") + parser.add_argument("--mask_sample_dist", type=str, default="poisson1", help="uniform or poissonx, e.g. poisson1, meaning the parameter lambda is 1, it will most likely sample 1 masks") + parser.add_argument("--min_gap", type=int, default=5, help="after sampled starts, delete later one if it closer to the former start than the min_gap") + parser.add_argument('--n_codebooks', type=int, default=4) + parser.add_argument('--text_vocab_size', type=int, default=100, help='Size of text vocabulary') + parser.add_argument('--text_pad_token', type=int, default=100, help='padding of the text tokens, not attended') + parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary") + parser.add_argument("--empty_token", default=2048, type=int, help="indicating the no token at the position for the codebook") + parser.add_argument('--eog', type=int, default=2049, help='End of generation token') + parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended') + + # model focused + parser.add_argument('--d_model', type=int, default=2048, help='Model dimension') + parser.add_argument('--audio_embedding_dim', type=int, default=2048, help='dimension for encodec continues embedding (before being quantized)') + parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding') + parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding') + parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding') + parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding') + parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer') + parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads') + parser.add_argument('--num_decoder_layers', type=int, default=16, help='Number of decoder layers') + parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this will be effective last, so will overwrite all previous load, including resume') + return parser \ No newline at end of file diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/giga_preprocessing/encodec_encode.py b/data/giga_preprocessing/encodec_encode.py new file mode 100644 index 0000000..f2a9915 --- /dev/null +++ b/data/giga_preprocessing/encodec_encode.py @@ -0,0 +1,160 @@ +import argparse +def parse_args(): + parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model") + parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!") + parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files") + parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs") + parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th") + parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes") + parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus") + parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate') + parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate') + parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate') + parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number') + return parser.parse_args() + +if __name__ == "__main__": + import logging + formatter = ( + "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + import os + import numpy as np + import torch + import torchaudio + import tqdm + import time + + args = parse_args() + + manifest_dir = args.manifest_root # this dir is scp-ed + audio_dir = args.audio_dir # this is scp-ed flac dir + encodec_signature = args.encodec_model_path.split("/")[-2] + save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}") + os.makedirs(save_codes_dir, exist_ok=True) + + + # model_sr = 16000 + # downsample_rate = 320 + # model_code_sr = 50 + def sort_by_audio_len(lens): + inds = np.argsort(lens).tolist() + logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.") + logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.") + logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.") + logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.") + return inds[::-1] + + def write_array_to_txt_file(array, filename): + with open(filename, 'w') as f: + for a in array[:-1]: + f.write(' '.join(map(str, a))+'\n') + f.write(' '.join(map(str, array[-1]))) + + + + class mydataset(torch.utils.data.Dataset): + def __init__(self, split): + super().__init__() + # self.data = gs[split] + self.split = split + self.audio_root = audio_dir + manifest_fn = os.path.join(manifest_dir, split+".txt") + with open(manifest_fn, "r") as rf: + self.data = [l.strip().split("\t") for l in rf.readlines()] + def __len__(self): + return len(self.data) + def __getitem__(self, ind): + try: + afn = self.data[ind][0] + fn = os.path.join(self.audio_root, afn) + audio, sr = torchaudio.load(fn) + assert sr == args.model_sr, sr + except Exception as e: + logging.info(f"{e}") + return None, None, None + assert audio.ndim==2 and audio.shape[0] == 1, audio.shape + return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0] + def collate(self, batch): + lens, audios, segment_ids = [], [], [] + for item in batch: + if item[0] != None: + audios.append(item[0]) + lens.append(item[1]) + segment_ids.append(item[2]) + return audios, lens, segment_ids + + # load the encodec model + from audiocraft.solvers import CompressionSolver + model = CompressionSolver.model_from_checkpoint(args.encodec_model_path) + model = model.cuda() + model = model.eval() + model = torch.nn.DataParallel(model) + + + # setup dataloader + mega_batch_size = 2100 + batch_size = args.batch_size + train_dataset = mydataset('train') + train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate) + validation_dataset = mydataset('validation') + validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate) + test_dataset = mydataset('test') + test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate) + splits = ['validation', 'test', 'train'] + loaders = [validation_loader, test_loader, train_loader] + # splits = ['validation'] # NOTE this is for debug, for example, see if the + # loaders = [validation_loader] + for split, loader in zip(splits, loaders): + skip = 0 + logging.info(f"now processing split {split}...") + mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size)) + # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size)) + logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples") + # with open(mani_fn, "a") as mani_wf: # resume from where we failed + for m, mega_batch in enumerate(loader): + logging.info(f"====================================") + logging.info(f"====================================") + logging.info(f"now processing mega step {m+1}/{mega_n_steps}") + lengths = np.array(mega_batch[1]) + sorted_inds = sort_by_audio_len(lengths) + for j in range(len(sorted_inds))[::-1]: + if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) + skip += 1 + del sorted_inds[j] + + n_steps = int(np.ceil(len(sorted_inds) / batch_size)) + for n in tqdm.tqdm(range(n_steps), disable=True): + inds_used = sorted_inds[n*batch_size:(n+1)*batch_size] + wav_batch = [mega_batch[0][id] for id in inds_used] + all_lens = [mega_batch[1][id] for id in inds_used] + segment_id_batch = [mega_batch[2][id] for id in inds_used] + # print(segment_id_batch) + padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T] + with torch.no_grad(): + if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes + codes = [] + inwav = padded_wav.cuda() + codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu()) + codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu()) + codes = torch.cat(codes, dim=0) + else: + encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model + # logging.info(f"encoded_frames: {encoded_frames[0].shape}") + codes = encoded_frames[0].cpu() + + for i, length in enumerate(all_lens): + save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt") + actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model + cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() + write_array_to_txt_file(cur_code, save_fn) + + # mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file + # if i == 10: + # raise + # break + # logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words") + logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short") + # break \ No newline at end of file diff --git a/data/gigaspeech.py b/data/gigaspeech.py new file mode 100644 index 0000000..0d855a6 --- /dev/null +++ b/data/gigaspeech.py @@ -0,0 +1,158 @@ +import os +import torch +import random +import copy +import logging +import shutil + +class dataset(torch.utils.data.Dataset): + def __init__(self, args, split): + super().__init__() + self.args = args + self.split = split + assert self.split in ['train', 'validation', 'test'] + manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt") + + with open(manifest_fn, "r") as rf: + data = [l.strip().split("\t") for l in rf.readlines()] + lengths_list = [int(item[-1]) for item in data] + self.data = [] + self.lengths_list = [] + for d, l in zip(data, lengths_list): + if l >= self.args.encodec_sr*self.args.audio_min_length: + if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length: + continue + self.data.append(d) + self.lengths_list.append(l) + logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}") + + # phoneme vocabulary + vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt") + shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt")) + with open(vocab_fn, "r") as f: + temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0] + self.phn2num = {item[1]:int(item[0]) for item in temp} + + self.symbol_set = set(["", "", "", ""]) + + def __len__(self): + return len(self.lengths_list) + + def _load_phn_enc(self, index): + item = self.data[index] + pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt") + ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt") + try: + with open(pf, "r") as p, open(ef, "r") as e: + phns = [l.strip() for l in p.readlines()] + assert len(phns) == 1, phns + x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["", "", "", ""], as they are not in training set annotation + encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks] + + assert len(encos) == self.args.n_codebooks, ef + if self.args.special_first: + y = [[int(n)+self.args.n_special for n in l] for l in encos] + else: + y = [[int(n) for n in l] for l in encos] + if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig): + y = y[:1] + except Exception as e: + logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted") + logging.info(f"error message: {e}") + return [], [[]] + + return x, y + + def __getitem__(self, index): + x, y = self._load_phn_enc(index) + x_len, y_len = len(x), len(y[0]) + + if x_len == 0 or y_len == 0: + return { + "x": None, + "x_len": None, + "y": None, + "y_len": None, + "y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token + "extra_mask_start": None # this is only used in VE1 + } + while y_len < self.args.encodec_sr*self.args.audio_min_length: + assert not self.args.dynamic_batching + index = random.choice(range(len(self))) # regenerate an index + x, y = self._load_phn_enc(index) + x_len, y_len = len(x), len(y[0]) + if self.args.drop_long: + while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length: + index = random.choice(range(len(self))) # regenerate an index + x, y = self._load_phn_enc(index) + x_len, y_len = len(x), len(y[0]) + + ### padding and cropping below ### + ### padding and cropping below ### + # adjust the length of encodec codes, pad to max_len or randomly crop + orig_y_len = copy.copy(y_len) + max_len = int(self.args.audio_max_length * self.args.encodec_sr) + if y_len > max_len: + audio_start = random.choice(range(0, y_len-max_len)) + for i in range(len(y)): + y[i] = y[i][audio_start:(audio_start+max_len)] + y_len = max_len + else: + audio_start = 0 + if not self.args.dynamic_batching: + pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len) + for i in range(len(y)): + y[i] = y[i] + pad + + # adjust text + # if audio is cropped, and text is longer than max, crop max based on how audio is cropped + if audio_start > 0 and len(x) > self.args.text_max_length: # if audio is longer than max and text is long than max, start text the way audio started + x = x[int(len(x)*audio_start/orig_y_len):] + if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end + x = x[:self.args.text_max_length] + + x_len = len(x) + if x_len > self.args.text_max_length: + text_start = random.choice(range(0, x_len - self.args.text_max_length)) + x = x[text_start:text_start+self.args.text_max_length] + x_len = self.args.text_max_length + elif self.args.pad_x and x_len <= self.args.text_max_length: + pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len) + x = x + pad + ### padding and cropping above ### + ### padding and cropping above ### + + return { + "x": torch.LongTensor(x), + "x_len": x_len, + "y": torch.LongTensor(y), + "y_len": y_len + } + + + def collate(self, batch): + out = {key:[] for key in batch[0]} + for item in batch: + if item['x'] == None: # deal with load failure + continue + for key, val in item.items(): + out[key].append(val) + res = {} + if self.args.pad_x: + res["x"] = torch.stack(out["x"], dim=0) + else: + res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token) + res["x_lens"] = torch.LongTensor(out["x_len"]) + if self.args.dynamic_batching: + if out['y'][0].ndim==2: + res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token) + res['y'] = res['y'].permute(1,2,0) # T B K -> B K T + else: + assert out['y'][0].ndim==1, out['y'][0].shape + res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token) + else: + res['y'] = torch.stack(out['y'], dim=0) + res["y_lens"] = torch.LongTensor(out["y_len"]) + res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1) + res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1) + return res \ No newline at end of file diff --git a/data/tokenizer.py b/data/tokenizer.py new file mode 100644 index 0000000..1495120 --- /dev/null +++ b/data/tokenizer.py @@ -0,0 +1,149 @@ +# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, Pattern, Union + +import numpy as np +import torch +import torchaudio +# from lhotse.features import FeatureExtractor +# from lhotse.utils import Seconds, compute_num_frames +from phonemizer.backend import EspeakBackend +from phonemizer.backend.espeak.language_switch import LanguageSwitch +from phonemizer.backend.espeak.words_mismatch import WordMismatch +from phonemizer.punctuation import Punctuation +from phonemizer.separator import Separator + + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="_", syllable="-", phone="|"), + preserve_punctuation=True, + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "keep-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + phonemizer = EspeakBackend( + language, + punctuation_marks=punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + ) + + self.backend = phonemizer + self.separator = separator + + def to_list(self, phonemized: str) -> List[str]: + fields = [] + for word in phonemized.split(self.separator.word): + # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. + pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) + fields.extend( + [p for p in pp if p != self.separator.phone] + + [self.separator.word] + ) + assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( + self.separator.phone + ) + return fields[:-1] + + def __call__(self, text, strip=True) -> List[List[str]]: + if isinstance(text, str): + text = [text] + + phonemized = self.backend.phonemize( + text, separator=self.separator, strip=strip, njobs=1 + ) + return [self.to_list(p) for p in phonemized] + + +def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: + phonemes = tokenizer([text.strip()]) + return phonemes[0] # k2symbols + +def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): + assert wav.shape[0] in [1, 2], "Audio must be mono or stereo." + if target_channels == 1: + wav = wav.mean(0, keepdim=True) + elif target_channels == 2: + *shape, _, length = wav.shape + wav = wav.expand(*shape, target_channels, length) + elif wav.shape[0] == 1: + wav = wav.expand(target_channels, -1) + wav = torchaudio.transforms.Resample(sr, target_sr)(wav) + return wav + +class AudioTokenizer: + """EnCodec audio.""" + + def __init__( + self, + device: Any = None, + signature = None + ) -> None: + from audiocraft.solvers import CompressionSolver + model = CompressionSolver.model_from_checkpoint(signature) + self.sample_rate = model.sample_rate + self.channels = model.channels + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + codes = self.codec.encode(wav.to(self.device)) + return [(codes[0], None)] + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + frames = frames[0][0] # [1,4,T] + return self.codec.decode(frames) + + + +def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1): + # Load and pre-process the audio waveform + if offset != -1 and num_frames!=-1: + wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames) + else: + wav, sr = torchaudio.load(audio_path) + wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) + wav = wav.unsqueeze(0) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = tokenizer.encode(wav) + return encoded_frames diff --git a/demo/84_121550_000074_000000.wav b/demo/84_121550_000074_000000.wav new file mode 100644 index 0000000..ed16974 Binary files /dev/null and b/demo/84_121550_000074_000000.wav differ diff --git a/demo/temp/84_121550_000074_000000.txt b/demo/temp/84_121550_000074_000000.txt new file mode 100644 index 0000000..ecac630 --- /dev/null +++ b/demo/temp/84_121550_000074_000000.txt @@ -0,0 +1 @@ +But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks, \ No newline at end of file diff --git a/demo/temp/mfa_alignments/84_121550_000074_000000.csv b/demo/temp/mfa_alignments/84_121550_000074_000000.csv new file mode 100644 index 0000000..ee0750b --- /dev/null +++ b/demo/temp/mfa_alignments/84_121550_000074_000000.csv @@ -0,0 +1,109 @@ +Begin,End,Label,Type,Speaker +0.03,0.18,but,words,temp +0.18,0.32,when,words,temp +0.32,0.49,i,words,temp +0.49,0.64,had,words,temp +0.64,1.19,approached,words,temp +1.22,1.58,so,words,temp +1.58,1.9,near,words,temp +1.9,2.07,to,words,temp +2.07,2.42,them,words,temp +2.53,2.61,the,words,temp +2.61,3.01,common,words,temp +3.05,3.62,object,words,temp +3.68,3.93,which,words,temp +3.93,4.02,the,words,temp +4.02,4.34,sense,words,temp +4.34,4.97,deceives,words,temp +5.04,5.54,lost,words,temp +5.54,6.0,not,words,temp +6.0,6.14,by,words,temp +6.14,6.67,distance,words,temp +6.79,7.06,any,words,temp +7.06,7.18,of,words,temp +7.18,7.34,its,words,temp +7.34,7.87,marks,words,temp +0.03,0.06,B,phones,temp +0.06,0.09,AH1,phones,temp +0.09,0.18,T,phones,temp +0.18,0.23,W,phones,temp +0.23,0.27,EH1,phones,temp +0.27,0.32,N,phones,temp +0.32,0.49,AY1,phones,temp +0.49,0.5,HH,phones,temp +0.5,0.6,AE1,phones,temp +0.6,0.64,D,phones,temp +0.64,0.7,AH0,phones,temp +0.7,0.83,P,phones,temp +0.83,0.87,R,phones,temp +0.87,0.99,OW1,phones,temp +0.99,1.12,CH,phones,temp +1.12,1.19,T,phones,temp +1.22,1.4,S,phones,temp +1.4,1.58,OW1,phones,temp +1.58,1.7,N,phones,temp +1.7,1.84,IH1,phones,temp +1.84,1.9,R,phones,temp +1.9,2.01,T,phones,temp +2.01,2.07,AH0,phones,temp +2.07,2.13,DH,phones,temp +2.13,2.3,EH1,phones,temp +2.3,2.42,M,phones,temp +2.53,2.55,DH,phones,temp +2.55,2.61,AH0,phones,temp +2.61,2.73,K,phones,temp +2.73,2.85,AA1,phones,temp +2.85,2.9,M,phones,temp +2.9,2.95,AH0,phones,temp +2.95,3.01,N,phones,temp +3.05,3.22,AA1,phones,temp +3.22,3.27,B,phones,temp +3.27,3.34,JH,phones,temp +3.34,3.48,EH0,phones,temp +3.48,3.54,K,phones,temp +3.54,3.62,T,phones,temp +3.68,3.69,HH,phones,temp +3.69,3.76,W,phones,temp +3.76,3.8,IH1,phones,temp +3.8,3.93,CH,phones,temp +3.93,3.95,DH,phones,temp +3.95,4.02,AH0,phones,temp +4.02,4.12,S,phones,temp +4.12,4.21,EH1,phones,temp +4.21,4.27,N,phones,temp +4.27,4.34,S,phones,temp +4.34,4.42,D,phones,temp +4.42,4.45,IH0,phones,temp +4.45,4.59,S,phones,temp +4.59,4.8,IY1,phones,temp +4.8,4.87,V,phones,temp +4.87,4.97,Z,phones,temp +5.04,5.12,L,phones,temp +5.12,5.33,AO1,phones,temp +5.33,5.42,S,phones,temp +5.42,5.54,T,phones,temp +5.54,5.7,N,phones,temp +5.7,5.89,AA1,phones,temp +5.89,6.0,T,phones,temp +6.0,6.05,B,phones,temp +6.05,6.14,AY1,phones,temp +6.14,6.24,D,phones,temp +6.24,6.3,IH1,phones,temp +6.3,6.38,S,phones,temp +6.38,6.45,T,phones,temp +6.45,6.51,AH0,phones,temp +6.51,6.57,N,phones,temp +6.57,6.67,S,phones,temp +6.79,6.89,EH1,phones,temp +6.89,6.95,N,phones,temp +6.95,7.06,IY0,phones,temp +7.06,7.13,AH0,phones,temp +7.13,7.18,V,phones,temp +7.18,7.22,IH0,phones,temp +7.22,7.29,T,phones,temp +7.29,7.34,S,phones,temp +7.34,7.39,M,phones,temp +7.39,7.49,AA1,phones,temp +7.49,7.58,R,phones,temp +7.58,7.69,K,phones,temp +7.69,7.87,S,phones,temp diff --git a/edit_utils.py b/edit_utils.py new file mode 100644 index 0000000..a9683f8 --- /dev/null +++ b/edit_utils.py @@ -0,0 +1,49 @@ +def get_span(orig, new, editType): + orig_list = orig.split(" ") + new_list = new.split(" ") + + flag = False # this indicate whether the actual edit follow the specified editType + if editType == "deletion": + assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}" + diff = len(orig_list) - len(new_list) + for i, (o, n) in enumerate(zip(orig_list, new_list)): + if o != n: # assume the index of the first different word is the starting index of the orig_span + + orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part + new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part + flag = True + break + + + elif editType == "insertion": + assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}" + diff = len(new_list) - len(orig_list) + for i, (o, n) in enumerate(zip(orig_list, new_list)): + if o != n: # insertion is just the opposite of deletion + new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same + orig_span = [i-1, i] + flag = True + break + + elif editType == "substitution": + new_span = [] + orig_span = [] + for i, (o, n) in enumerate(zip(orig_list, new_list)): + if o != n: + new_span = [i] + orig_span = [i] + break + assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}" + for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])): + if o != n: + new_span.append(len(new_list) - j -1) + orig_span.append(len(orig_list) - j - 1) + flag = True + break + else: + raise RuntimeError(f"editType unknown: {editType}") + + if not flag: + raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}") + + return orig_span, new_span \ No newline at end of file diff --git a/inference_speech_editing.ipynb b/inference_speech_editing.ipynb new file mode 100644 index 0000000..64340f7 --- /dev/null +++ b/inference_speech_editing.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pyp/miniconda3/envs/voicecraft/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "# import libs\n", + "import torch\n", + "import torchaudio\n", + "\n", + "from data.tokenizer import (\n", + " AudioTokenizer,\n", + " TextTokenizer,\n", + ")\n", + "\n", + "from models import voicecraft\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# hyperparameters for inference\n", + "left_margin = 0.08\n", + "right_margin = 0.08\n", + "seed = 1\n", + "codec_audio_sr = 16000\n", + "codec_sr = 50\n", + "top_k = 0\n", + "top_p = 0.8\n", + "temperature = 1\n", + "kvcache = 0\n", + "silence_tokens = [1388,1898,131]\n", + "stop_repetition = -1 # do not stop repetition on silence\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# point to the original file or record the file\n", + "# write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file\n", + "orig_audio = \"./demo/84_121550_000074_000000.wav\"\n", + "orig_transcript = \"But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,\"\n", + "# move the audio and transcript to temp folder\n", + "temp_folder = \"./demo/temp\"\n", + "os.makedirs(temp_folder, exist_ok=True)\n", + "os.system(f\"cp {orig_audio} {temp_folder}\")\n", + "filename = os.path.splitext(orig_audio.split(\"/\")[-1])[0]\n", + "with open(f\"{temp_folder}/{filename}.txt\", \"w\") as f:\n", + " f.write(orig_transcript)\n", + "# run MFA to get the alignment\n", + "align_temp = f\"{temp_folder}/mfa_alignments\"\n", + "os.makedirs(align_temp, exist_ok=True)\n", + "os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}\")\n", + "# if it fail, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue\n", + "# os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000\")\n", + "audio_fn = f\"{temp_folder}/{filename}.wav\"\n", + "transcript_fn = f\"{temp_folder}/{filename}.txt\"\n", + "align_fn = f\"{align_temp}/{filename}.csv\"\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1)\n" + ] + } + ], + "source": [ + "editTypes_set = set(['substitution', 'insertion', 'deletion'])\n", + "# propose what do you want the target modified transcript to be\n", + "target_transcript = \"But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,\"\n", + "edit_type = \"substitution\"\n", + "assert edit_type in editTypes_set, f\"Invalid edit type {edit_type}. Must be one of {editTypes_set}.\"\n", + "\n", + "# if you want to do a second modification on top of the first one, write down the second modification (target_transcript2, type_of_modification2)\n", + "# make sure the two modification do not overlap, if they do, you need to combine them into one modification\n", + "\n", + "# run the script to turn user input to the format that the model can take\n", + "from edit_utils import get_span\n", + "orig_span, new_span = get_span(orig_transcript, target_transcript, edit_type)\n", + "if orig_span[0] > orig_span[1]:\n", + " RuntimeError(f\"example {audio_fn} failed\")\n", + "if orig_span[0] == orig_span[1]:\n", + " orig_span_save = [orig_span[0]]\n", + "else:\n", + " orig_span_save = orig_span\n", + "if new_span[0] == new_span[1]:\n", + " new_span_save = [new_span[0]]\n", + "else:\n", + " new_span_save = new_span\n", + "\n", + "orig_span_save = \",\".join([str(item) for item in orig_span_save])\n", + "new_span_save = \",\".join([str(item) for item in new_span_save])\n", + "from inference_speech_editing_scale import get_mask_interval\n", + "\n", + "start, end = get_mask_interval(align_fn, orig_span_save, edit_type)\n", + "info = torchaudio.info(audio_fn)\n", + "audio_dur = info.num_frames / info.sample_rate\n", + "morphed_span = (max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur)) # in seconds\n", + "\n", + "# span in codec frames\n", + "mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]\n", + "mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now\n", + "\n", + "# load model, tokenizer, and other necessary files\n", + "ckpt_fn = \"/data/scratch/pyp/exp_pyp/VoiceCraft/gigaspeech/pretrained_830M/best_bundle.pth\"\n", + "encodec_fn = \"/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th\"\n", + "ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n", + "model = voicecraft.VoiceCraft(ckpt[\"config\"])\n", + "model.load_state_dict(ckpt[\"model\"])\n", + "model.to(device)\n", + "model.eval()\n", + "\n", + "phn2num = ckpt['phn2num']\n", + "\n", + "text_tokenizer = TextTokenizer(backend=\"espeak\")\n", + "audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu\n", + "\n", + "# run the model to get the output\n", + "from inference_speech_editing_scale import inference_one_sample\n", + "\n", + "decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens}\n", + "orig_audio, new_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, mask_interval, device, decode_config)\n", + " \n", + "# save segments for comparison\n", + "orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()\n", + "# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n", + "\n", + "# output_dir\n", + "output_dir = \"./demo/generated_se\"\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "save_fn_new = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{seed}.wav\"\n", + "\n", + "torchaudio.save(save_fn_new, new_audio, codec_audio_sr)\n", + "\n", + "save_fn_orig = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav\"\n", + "if not os.path.isfile(save_fn_orig):\n", + " orig_audio, orig_sr = torchaudio.load(audio_fn)\n", + " if orig_sr != codec_audio_sr:\n", + " orig_audio = torchaudio.transforms.Resample(orig_sr, codec_audio_sr)(orig_audio)\n", + " torchaudio.save(save_fn_orig, orig_audio, codec_audio_sr)\n", + "\n", + "# if you get error importing T5 in transformers\n", + "# try \n", + "# pip uninstall Pillow\n", + "# pip install Pillow\n", + "# you are likely to get warning looks like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "voicecraft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/inference_speech_editing_scale.py b/inference_speech_editing_scale.py new file mode 100644 index 0000000..b034d95 --- /dev/null +++ b/inference_speech_editing_scale.py @@ -0,0 +1,226 @@ +import argparse, pickle +import logging +import os, random +import numpy as np +import torch +import torchaudio + +from data.tokenizer import ( + AudioTokenizer, + TextTokenizer, + tokenize_audio, + tokenize_text +) + +from models import voicecraft +import argparse, time, tqdm + +# this script only works for the musicgen architecture +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file") + parser.add_argument("--audio_root", type=str, default="path/to/audio_folder") + parser.add_argument("--exp_dir", type=str, default="path/to/model_folder") + parser.add_argument("--left_margin", type=float, default=0.08, help="extra space on the left to the word boundary") + parser.add_argument("--right_margin", type=float, default=0.08, help="extra space on the right to the word boundary") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for') + parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes') + parser.add_argument("--top_k", type=int, default=-1, help="sampling param") + parser.add_argument("--top_p", type=float, default=0.8, help="sampling param") + parser.add_argument("--temperature", type=float, default=1.0, help="sampling param") + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--signature", type=str, default=None, help="path to the encodec model") + parser.add_argument("--stop_repetition", type=int, default=2, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it") + parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without') + parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") + return parser.parse_args() + +@torch.no_grad() +def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, device, decode_config): + # phonemize + text_tokens = [phn2num[phn] for phn in + tokenize_text( + text_tokenizer, text=target_text.strip() + ) if phn in phn2num + ] + text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) + text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) + + encoded_frames = tokenize_audio(audio_tokenizer, audio_fn) + original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K] + assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape + logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.") + + # forward + stime = time.time() + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] + mask_interval=mask_interval.unsqueeze(0).to(device), + top_k=decode_config['top_k'], + top_p=decode_config['top_p'], + temperature=decode_config['temperature'], + stop_repetition=decode_config['stop_repetition'], + kvcache=decode_config['kvcache'], + silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens'], + ) # output is [1,K,T] + logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.") + if type(encoded_frames) == tuple: + encoded_frames = encoded_frames[0] + logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.") + + + # decode (both original and generated) + original_sample = audio_tokenizer.decode( + [(original_audio.transpose(2,1), None)] # [1,T,8] -> [1,8,T] + ) + generated_sample = audio_tokenizer.decode( + [(encoded_frames, None)] + ) + + return original_sample, generated_sample + +def get_model(exp_dir, device=None): + with open(os.path.join(exp_dir, "args.pkl"), "rb") as f: + model_args = pickle.load(f) + + logging.info("load model weights...") + model = voicecraft.VoiceCraft(model_args) + ckpt_fn = os.path.join(exp_dir, "best_bundle.pth") + ckpt = torch.load(ckpt_fn, map_location='cpu')['model'] + phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num'] + model.load_state_dict(ckpt) + del ckpt + logging.info("done loading weights...") + if device == None: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + model.to(device) + model.eval() + return model, model_args, phn2num + + +def get_mask_interval(ali_fn, word_span_ind, editType): + with open(ali_fn, "r") as rf: + data = [l.strip().split(",") for l in rf.readlines()] + data = data[1:] + tmp = word_span_ind.split(",") + s, e = int(tmp[0]), int(tmp[-1]) + start = None + for j, item in enumerate(data): + if j == s and item[3] == "words": + if editType == 'insertion': + start = float(item[1]) + else: + start = float(item[0]) + if j == e and item[3] == "words": + if editType == 'insertion': + end = float(item[0]) + else: + end = float(item[1]) + assert start != None + break + return (start, end) + +if __name__ == "__main__": + def seed_everything(seed): + os.environ['PYTHONHASHSEED'] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + formatter = ( + "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + # args.device = 'cpu' + args.allowed_repeat_tokens = eval(args.allowed_repeat_tokens) + seed_everything(args.seed) + + # load model + stime = time.time() + logging.info(f"loading model from {args.exp_dir}") + model, model_args, phn2num = get_model(args.exp_dir) + if not os.path.isfile(model_args.exp_dir): + model_args.exp_dir = args.exp_dir + logging.info(f"loading model done, took {time.time() - stime:.4f} sec") + + # setup text and audio tokenizer + text_tokenizer = TextTokenizer(backend="espeak") + audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu + + with open(args.manifest_fn, "r") as rf: + manifest = [l.strip().split("\t") for l in rf.readlines()] + manifest = manifest[1:] + + # wav_fn txt_fn alingment_fn num_words word_span_ind + audio_fns = [] + target_texts = [] + mask_intervals = [] + edit_types = [] + new_spans = [] + orig_spans = [] + os.makedirs(args.output_dir, exist_ok=True) + if args.crop_concat: + mfa_temp = f"{args.output_dir}/mfa_temp" + os.makedirs(mfa_temp, exist_ok=True) + for item in manifest: + audio_fn = os.path.join(args.audio_root, item[0]) + temp = torchaudio.info(audio_fn) + audio_dur = temp.num_frames/temp.sample_rate + audio_fns.append(audio_fn) + target_text = item[2].split("|")[-1] + edit_types.append(item[5].split("|")) + new_spans.append(item[4].split("|")) + orig_spans.append(item[3].split("|")) + target_texts.append(target_text) # the last transcript is the target + # mi needs to be created from word_ind_span and alignment_fn, along with args.left_margin and args.right_margin + mis = [] + all_ind_intervals = item[3].split("|") + editTypes = item[5].split("|") + smaller_indx = [] + alignment_fn = os.path.join(args.audio_root, "aligned", item[0].replace(".wav", ".csv")) + if not os.path.isfile(alignment_fn): + alignment_fn = alignment_fn.replace("/aligned/", "/aligned_csv/") + assert os.path.isfile(alignment_fn), alignment_fn + for ind_inter,editType in zip(all_ind_intervals, editTypes): + # print(ind_inter) + mi = get_mask_interval(alignment_fn, ind_inter, editType) + mi = (max(mi[0] - args.left_margin, 1/args.codec_sr), min(mi[1] + args.right_margin, audio_dur)) # in seconds + mis.append(mi) + smaller_indx.append(mi[0]) + ind = np.argsort(smaller_indx) + mis = [mis[id] for id in ind] + mask_intervals.append(mis) + + + + for i, (audio_fn, target_text, mask_interval) in enumerate(tqdm.tqdm(zip(audio_fns, target_texts, mask_intervals))): + orig_mask_interval = mask_interval + mask_interval = [[round(cmi[0]*args.codec_sr), round(cmi[1]*args.codec_sr)] for cmi in mask_interval] + # logging.info(f"i: {i}, mask_interval: {mask_interval}") + mask_interval = torch.LongTensor(mask_interval) # [M,2] + orig_audio, new_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, args.device, vars(args)) + + # save segments for comparison + orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu() + # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}") + + save_fn_new = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{args.seed}.wav" + + torchaudio.save(save_fn_new, new_audio, args.codec_audio_sr) + + save_fn_orig = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav" + if not os.path.isfile(save_fn_orig): + orig_audio, orig_sr = torchaudio.load(audio_fn) + if orig_sr != args.codec_audio_sr: + orig_audio = torchaudio.transforms.Resample(orig_sr, args.codec_audio_sr)(orig_audio) + torchaudio.save(save_fn_orig, orig_audio, args.codec_audio_sr) + diff --git a/inference_tts.ipynb b/inference_tts.ipynb new file mode 100644 index 0000000..75c25a2 --- /dev/null +++ b/inference_tts.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pyp/miniconda3/envs/voicecraft/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "# import libs\n", + "import torch\n", + "import torchaudio\n", + "\n", + "from data.tokenizer import (\n", + " AudioTokenizer,\n", + " TextTokenizer,\n", + ")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# hyperparameters for inference\n", + "left_margin = 0.08\n", + "right_margin = 0.08\n", + "seed = 1\n", + "codec_audio_sr = 16000\n", + "codec_sr = 50\n", + "top_k = 0\n", + "top_p = 0.8\n", + "temperature = 1\n", + "kvcache = 0\n", + "silence_tokens=[1388,1898,131]\n", + "# if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1\n", + "stop_repetition = 2\n", + "# if there are long silence or unnaturally strecthed words, increase sample_batch_size to 2, 3 or even 4\n", + "# what this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest\n", + "sample_batch_size = 1\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# point to the original file or record the file\n", + "# write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file\n", + "orig_audio = \"/home/pyp/VoiceCraft/demo/84_121550_000074_000000.wav\"\n", + "orig_transcript = \"But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,\"\n", + "\n", + "# move the audio and transcript to temp folder\n", + "temp_folder = \"/home/pyp/VoiceCraft/demo/temp\"\n", + "os.makedirs(temp_folder, exist_ok=True)\n", + "os.system(f\"cp {orig_audio} {temp_folder}\")\n", + "filename = os.path.splitext(orig_audio.split(\"/\")[-1])[0]\n", + "with open(f\"{temp_folder}/{filename}.txt\", \"w\") as f:\n", + " f.write(orig_transcript)\n", + "# run MFA to get the alignment\n", + "align_temp = f\"{temp_folder}/mfa_alignments\"\n", + "os.makedirs(align_temp, exist_ok=True)\n", + "os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}\")\n", + "# if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue\n", + "# os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000\")\n", + "audio_fn = f\"{temp_folder}/{filename}.wav\"\n", + "transcript_fn = f\"{temp_folder}/{filename}.txt\"\n", + "align_fn = f\"{align_temp}/{filename}.csv\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Dora directory: /tmp/audiocraft_pyp\n" + ] + } + ], + "source": [ + "# take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt\n", + "cut_off_sec = 3.01 # according to forced-alignment file, the word \"common\" stop as 3.01 sec\n", + "target_transcript = \"But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well!\"\n", + "info = torchaudio.info(audio_fn)\n", + "audio_dur = info.num_frames / info.sample_rate\n", + "\n", + "assert cut_off_sec < audio_dur, f\"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}\"\n", + "prompt_end_frame = int(cut_off_sec * info.sample_rate)\n", + "\n", + "\n", + "# # load model, tokenizer, and other necessary files\n", + "from models import voicecraft\n", + "ckpt_fn = \"/data/scratch/pyp/exp_pyp/VoiceCraft/gigaspeech/pretrained_830M/best_bundle.pth\"\n", + "encodec_fn = \"/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th\"\n", + "ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n", + "model = voicecraft.VoiceCraft(ckpt[\"config\"])\n", + "model.load_state_dict(ckpt[\"model\"])\n", + "model.to(device)\n", + "model.eval()\n", + "\n", + "phn2num = ckpt['phn2num']\n", + "\n", + "text_tokenizer = TextTokenizer(backend=\"espeak\")\n", + "audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu\n", + "\n", + "# run the model to get the output\n", + "decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n", + "from inference_tts_scale import inference_one_sample\n", + "concated_audio, gen_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n", + " \n", + "# save segments for comparison\n", + "concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n", + "# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n", + "\n", + "# output_dir\n", + "output_dir = \"/home/pyp/VoiceCraft/demo/generated_tts\"\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "seg_save_fn_gen = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav\"\n", + "seg_save_fn_concat = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav\" \n", + "\n", + "torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)\n", + "torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)\n", + "\n", + "# if you get error importing T5 in transformers\n", + "# try \n", + "# pip uninstall Pillow\n", + "# pip install Pillow\n", + "# you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "voicecraft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/inference_tts_scale.py b/inference_tts_scale.py new file mode 100644 index 0000000..2ebb78c --- /dev/null +++ b/inference_tts_scale.py @@ -0,0 +1,190 @@ +import argparse, pickle +import logging +import os, random +import numpy as np +import torch +import torchaudio + +from data.tokenizer import ( + AudioTokenizer, + TextTokenizer, + tokenize_audio, + tokenize_text +) + +from models import voicecraft +import argparse, time, tqdm + + +# this script only works for the musicgen architecture +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file") + parser.add_argument("--audio_root", type=str, default="path/to/audio_folder") + parser.add_argument("--exp_dir", type=str, default="path/to/model_folder") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for') + parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes') + parser.add_argument("--top_k", type=int, default=0, help="sampling param") + parser.add_argument("--top_p", type=float, default=0.8, help="sampling param") + parser.add_argument("--temperature", type=float, default=1.0, help="sampling param") + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--signature", type=str, default=None, help="path to the encodec model") + parser.add_argument("--crop_concat", type=int, default=0) + parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it") + parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without') + parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation") + parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") + return parser.parse_args() + + +@torch.no_grad() +def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame): + # phonemize + text_tokens = [phn2num[phn] for phn in + tokenize_text( + text_tokenizer, text=target_text.strip() + ) if phn in phn2num + ] + text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) + text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) + + # encode audio + encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame) + original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K] + assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape + logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.") + + # forward + stime = time.time() + if decode_config['sample_batch_size'] <= 1: + logging.info(f"running inference with batch size 1") + concat_frames, gen_frames = model.inference_tts( + text_tokens.to(device), + text_tokens_lens.to(device), + original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] + top_k=decode_config['top_k'], + top_p=decode_config['top_p'], + temperature=decode_config['temperature'], + stop_repetition=decode_config['stop_repetition'], + kvcache=decode_config['kvcache'], + silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens'] + ) # output is [1,K,T] + else: + logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.") + concat_frames, gen_frames = model.inference_tts_batch( + text_tokens.to(device), + text_tokens_lens.to(device), + original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] + top_k=decode_config['top_k'], + top_p=decode_config['top_p'], + temperature=decode_config['temperature'], + stop_repetition=decode_config['stop_repetition'], + kvcache=decode_config['kvcache'], + batch_size = decode_config['sample_batch_size'], + silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens'] + ) # output is [1,K,T] + logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.") + + logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.") + + # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)): + # logging.info(f"{timestamp}: {codes.tolist()}") + # decode (both original and generated) + concat_sample = audio_tokenizer.decode( + [(concat_frames, None)] # [1,T,8] -> [1,8,T] + ) + gen_sample = audio_tokenizer.decode( + [(gen_frames, None)] + ) + + # return + return concat_sample, gen_sample + +def get_model(exp_dir, device=None): + with open(os.path.join(exp_dir, "args.pkl"), "rb") as f: + model_args = pickle.load(f) + + logging.info("load model weights...") + model = voicecraft.VoiceCraft(model_args) + ckpt_fn = os.path.join(exp_dir, "best_bundle.pth") + ckpt = torch.load(ckpt_fn, map_location='cpu')['model'] + phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num'] + model.load_state_dict(ckpt) + del ckpt + logging.info("done loading weights...") + if device == None: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + model.to(device) + model.eval() + return model, model_args, phn2num + +if __name__ == "__main__": + def seed_everything(seed): + os.environ['PYTHONHASHSEED'] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + formatter = ( + "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + # args.device='cpu' + seed_everything(args.seed) + + os.makedirs(args.output_dir, exist_ok=True) + # load model + + with open(args.manifest_fn, "r") as rf: + manifest = [l.strip().split("\t") for l in rf.readlines()] + manifest = manifest[1:] + manifest = [[item[0], item[2], item[3], item[1], item[5]] for item in manifest] + + stime = time.time() + logging.info(f"loading model from {args.exp_dir}") + model, model_args, phn2num = get_model(args.exp_dir) + logging.info(f"loading model done, took {time.time() - stime:.4f} sec") + + # setup text and audio tokenizer + text_tokenizer = TextTokenizer(backend="espeak") + audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu + + audio_fns = [] + texts = [] + prompt_end_frames = [] + new_audio_fns = [] + text_to_syn = [] + + for item in manifest: + audio_fn = os.path.join(args.audio_root, item[0]) + audio_fns.append(audio_fn) + temp = torchaudio.info(audio_fn) + prompt_end_frames.append(round(float(item[2])*temp.sample_rate)) + texts.append(item[1]) + new_audio_fns.append(item[-2]) + all_text = item[1].split(" ") + start_ind = int(item[-1].split(",")[0]) + text_to_syn.append(" ".join(all_text[start_ind:])) + + for i, (audio_fn, text, prompt_end_frame, new_audio_fn, to_syn) in enumerate(tqdm.tqdm((zip(audio_fns, texts, prompt_end_frames, new_audio_fns, text_to_syn)))): + output_expected_sr = args.codec_audio_sr + concated_audio, gen_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, text, args.device, vars(args), prompt_end_frame) + + # save segments for comparison + concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() + if output_expected_sr != args.codec_audio_sr: + gen_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(gen_audio) + concated_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(concated_audio) + + seg_save_fn_gen = f"{args.output_dir}/gen_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav" + seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav" + + torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr) + torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..0d3fac5 --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +from pathlib import Path +import torch +import pickle +import argparse +import logging +import torch.distributed as dist +from config import MyParser +from steps import trainer + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + torch.cuda.empty_cache() + args = MyParser().parse_args() + logging.info(args) + exp_dir = Path(args.exp_dir) + exp_dir.mkdir(exist_ok=True, parents=True) + logging.info(f"exp_dir: {str(exp_dir)}") + + if args.resume: + resume = args.resume + assert(bool(args.exp_dir)) + with open("%s/args.pkl" % args.exp_dir, "rb") as f: + old_args = pickle.load(f) + new_args = vars(args) + old_args = vars(old_args) + for key in new_args: + if key not in old_args or old_args[key] != new_args[key]: + old_args[key] = new_args[key] + args = argparse.Namespace(**old_args) + args.resume = resume + else: + with open("%s/args.pkl" % args.exp_dir, "wb") as f: + pickle.dump(args, f) + + dist.init_process_group(backend='nccl', init_method='env://') + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + my_trainer = trainer.Trainer(args, world_size, rank) + my_trainer.train() \ No newline at end of file diff --git a/models/codebooks_patterns.py b/models/codebooks_patterns.py new file mode 100644 index 0000000..24c6319 --- /dev/null +++ b/models/codebooks_patterns.py @@ -0,0 +1,538 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. + ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. + layout: PatternLayout + timesteps: int + n_q: int + + def __post_init__(self): + assert len(self.layout) > 0 + assert self.layout[0] == [] + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + # logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.n_q)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + assert coord.t >= last_q_timestep, \ + f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. + """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (Union[torch.device, str]): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is n_q * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: n_q * timesteps + indexes[:] = n_q * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.view(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + n_q (int): Number of codebooks. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (Union[torch.device, str]): Device for created tensors. + Returns: + torch.Tensor: Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # TODO(jade): Do we want to further truncate to only valid timesteps here as well? + timesteps = self.timesteps + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output: + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = n_q * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. + + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. + special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. + + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `n_q`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + n_q (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. + """ + def __init__(self, n_q: int, cached: bool = True): + assert n_q > 0 + self.n_q = n_q + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total numer of timesteps. + """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + n_q (int): Number of codebooks. + delays (Optional[List[int]]): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(n_q) + if delays is None: + delays = list(range(n_q)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.n_q + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.n_q): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class ParallelPatternProvider(DelayedPatternProvider): + """Provider for parallel pattern across codebooks. + This pattern provider is a special case of the delayed pattern with actually no delay, + hence delays=repeat(0, n_q). + + Args: + n_q (int): Number of codebooks. + """ + def __init__(self, n_q: int): + super().__init__(n_q, [0] * n_q) + + +class UnrolledPatternProvider(CodebooksPatternProvider): + """Provider for unrolling codebooks pattern. + This pattern provider enables to represent the codebook flattened completely or only to some extend + while also specifying a given delay between the flattened codebooks representation, allowing to + unroll the codebooks in the sequence. + + Example: + 1. Flattening of the codebooks. + By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), + taking n_q = 3 and timesteps = 4: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step + for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example + taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks + allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the + same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] + and delays = [0, 3, 3]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, S, 1, S, 2, S, 3, S, 4], + [S, S, S, 1, S, 2, S, 3, S, 4], + [1, 2, 3, S, 4, S, 5, S, 6, S]] + + Args: + n_q (int): Number of codebooks. + flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined, + the codebooks will be flattened to 1 codebook per step, meaning that the sequence will + have n_q extra steps for each timestep. + delays (Optional[List[int]]): Delay for each of the codebooks. If not defined, + no delay is added and therefore will default to [0] * ``n_q``. + Note that two codebooks that will be flattened to the same inner step + should have the same delay, otherwise the pattern is considered as invalid. + """ + FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) + + def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, + delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if flattening is None: + flattening = list(range(n_q)) + if delays is None: + delays = [0] * n_q + assert len(flattening) == n_q + assert len(delays) == n_q + assert sorted(flattening) == flattening + assert sorted(delays) == delays + self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) + self.max_delay = max(delays) + + def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): + """Build a flattened codebooks representation as a dictionary of inner step + and the actual codebook indices corresponding to the flattened codebook. For convenience, we + also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. + """ + flattened_codebooks: dict = {} + for q, (inner_step, delay) in enumerate(zip(flattening, delays)): + if inner_step not in flattened_codebooks: + flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) + else: + flat_codebook = flattened_codebooks[inner_step] + assert flat_codebook.delay == delay, ( + "Delay and flattening between codebooks is inconsistent: ", + "two codebooks flattened to the same position should have the same delay." + ) + flat_codebook.codebooks.append(q) + flattened_codebooks[inner_step] = flat_codebook + return flattened_codebooks + + @property + def _num_inner_steps(self): + """Number of inner steps to unroll between timesteps in order to flatten the codebooks. + """ + return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 + + def num_virtual_steps(self, timesteps: int) -> int: + return timesteps * self._num_inner_steps + 1 + + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern for delay across codebooks. + + Args: + timesteps (int): Total numer of timesteps. + """ + # the PatternLayout is built as a tuple of sequence position and list of coordinates + # so that it can be reordered properly given the required delay between codebooks of given timesteps + indexed_out: list = [(-1, [])] + max_timesteps = timesteps + self.max_delay + for t in range(max_timesteps): + # for each timestep, we unroll the flattened codebooks, + # emitting the sequence step with the corresponding delay + for step in range(self._num_inner_steps): + if step in self._flattened_codebooks: + # we have codebooks at this virtual step to emit + step_codebooks = self._flattened_codebooks[step] + t_for_q = t + step_codebooks.delay + coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] + if t_for_q < max_timesteps and t < max_timesteps: + indexed_out.append((t_for_q, coords)) + else: + # there is no codebook in this virtual step so we emit an empty list + indexed_out.append((t, [])) + out = [coords for _, coords in sorted(indexed_out)] + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class VALLEPattern(CodebooksPatternProvider): + """Almost VALL-E style pattern. We futher allow some delays for the + codebooks other than the first one. + + Args: + n_q (int): Number of codebooks. + delays (Optional[List[int]]): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if delays is None: + delays = [0] * (n_q - 1) + self.delays = delays + assert len(self.delays) == self.n_q - 1 + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for t in range(timesteps): + out.append([LayoutCoord(t, 0)]) + max_delay = max(self.delays) + for t in range(timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= 0: + v.append(LayoutCoord(t_for_q, q + 1)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class MusicLMPattern(CodebooksPatternProvider): + """Almost MusicLM style pattern. This is equivalent to full flattening + but in a different order. + + Args: + n_q (int): Number of codebooks. + group_by (int): Number of codebooks to group together. + """ + def __init__(self, n_q: int, group_by: int = 2): + super().__init__(n_q) + self.group_by = group_by + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for offset in range(0, self.n_q, self.group_by): + for t in range(timesteps): + for q in range(offset, offset + self.group_by): + out.append([LayoutCoord(t, q)]) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) \ No newline at end of file diff --git a/models/modules/__init__.py b/models/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/modules/activation.py b/models/modules/activation.py new file mode 100644 index 0000000..cea9b01 --- /dev/null +++ b/models/modules/activation.py @@ -0,0 +1,653 @@ +# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024 +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter +import logging +from typing import Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from torch.types import _dtype as DType +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, +) -> Optional[Tensor]: + + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = torch.is_floating_point(mask) + if _mask_dtype != torch.bool and not _mask_is_float: + raise AssertionError( + f"only bool and floating types of {mask_name} are supported") + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." + ) + if not _mask_is_float: + mask = ( + torch.zeros_like(mask, dtype=target_type) + .masked_fill_(mask, float("-inf")) + ) + return mask + +def _in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + b: Optional[Tensor] = None, +) -> List[Tensor]: + r""" + Performs the in-projection step of the attention operation, using packed weights. + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. For self-attention, + these are typically the same tensor; for encoder-decoder attention, + k and v are typically the same tensor. (We take advantage of these + identities for performance if they are present.) Regardless, q, k and v + must share a common embedding dimension; otherwise their shapes may vary. + w: projection weights for q, k and v, packed into a single tensor. Weights + are packed along dimension 0, in q, k, v order. + b: optional projection biases for q, k and v, packed into a single tensor + in q, k, v order. + + Shape: + Inputs: + - q: :math:`(..., E)` where E is the embedding dimension + - k: :math:`(..., E)` where E is the embedding dimension + - v: :math:`(..., E)` where E is the embedding dimension + - w: :math:`(E * 3, E)` where E is the embedding dimension + - b: :math:`E * 3` where E is the embedding dimension + + Output: + - in output list :math:`[q', k', v']`, each output tensor will have the + same shape as the corresponding input tensor. + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = F.linear(q, w, b) + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = F.linear(q, w_q, b_q) + kv_proj = F.linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + +def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: + if input is None: + return None + elif isinstance(input, torch.Tensor): + return input.dtype + raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``forward()`` will use a special optimized implementation if all of the following + conditions are met: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This + restriction will be loosened in the future.) + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + # go down this route with voicecraft + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: # True by default + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + past: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif ( + self.in_proj_bias is not None + and query.dtype != self.in_proj_bias.dtype + ): + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None + and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = ( + "key_padding_mask is not supported with NestedTensor input" + ) + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = ( + "some Tensor argument is neither CUDA nor CPU" + ) + elif torch.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask + if key_padding_mask is not None + else attn_mask, + need_weights, + average_attn_weights, + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [ + x.transpose(1, 0) for x in (query, key, value) + ] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) + else: + # re-write the self.attention here, to get k, v cache + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + num_heads = self.num_heads + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + head_dim = self.embed_dim // self.num_heads + assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}" + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) + # k_present, v_present = k, v + + # + # reshape q, k, v for multihead attention and make em batch first + # + + q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim) + src_len = k.size(1) + if past is not None and past.ndim > 2: + expected_src_len = src_len + past[0].shape[-2] + else: + expected_src_len = src_len + + + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, expected_src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + if key_padding_mask is not None: + assert key_padding_mask.shape == (bsz, expected_src_len), \ + f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \ + expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + if not self.training: + dropout_p = 0.0 + else: + dropout_p = self.dropout + + if need_weights: + raise NotImplementedError("need_weights not implemented for voicecraft") + # B, Nt, E = q.shape + # q_scaled = q / math.sqrt(E) + + # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + + # if attn_mask is not None: + # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) + # else: + # attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) + # attn_output_weights = softmax(attn_output_weights, dim=-1) + # if dropout_p > 0.0: + # attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + # attn_output = torch.bmm(attn_output_weights, v) + + # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + # attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + # # optionally average attention weights over heads + # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + # if average_attn_weights: + # attn_output_weights = attn_output_weights.mean(dim=1) + + # if not is_batched: + # # squeeze the output if input was unbatched + # attn_output = attn_output.squeeze(1) + # attn_output_weights = attn_output_weights.squeeze(0) + # return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + # logging.info(f"shape of past: {past.shape}") + if past is not None: + present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim) + if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache + pk, pv = past + k = torch.cat([pk, k], dim=-2) + v = torch.cat([pv, v], dim=-2) + else: + present = None + attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False) + attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + + attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + # if self.training: + # return attn_output, None + # else: + # return (attn_output, present), None + + # harded coded, the code do not support returning attn weigths yet + attn_output_weights=None + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), present + else: + return attn_output, present + diff --git a/models/modules/embedding.py b/models/modules/embedding.py new file mode 100644 index 0000000..96bf1fb --- /dev/null +++ b/models/modules/embedding.py @@ -0,0 +1,98 @@ +# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + + +class TokenEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.dim_model = dim_model + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + X = self.word_embeddings(x) + X = self.dropout(X) + + return X + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.dim_model = dim_model + self.x_scale = math.sqrt(dim_model) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + + self.reverse = False + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, 4000)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.dim_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange( + 0, x.size(1), dtype=torch.float32 + ).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.dim_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.dim_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype).detach() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(output) \ No newline at end of file diff --git a/models/modules/sampling.py b/models/modules/sampling.py new file mode 100644 index 0000000..7acdcd4 --- /dev/null +++ b/models/modules/sampling.py @@ -0,0 +1,63 @@ +import torch +import torch.nn.functional as F + +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min( + max(top_k, min_tokens_to_keep), logits.size(-1) + ) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + +def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): + # temperature: (`optional`) float + # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + # top_k: (`optional`) int + # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + # top_p: (`optional`) float + # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + # Sample + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return token \ No newline at end of file diff --git a/models/modules/scaling.py b/models/modules/scaling.py new file mode 100644 index 0000000..cd245ea --- /dev/null +++ b/models/modules/scaling.py @@ -0,0 +1,1406 @@ +# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import logging +import random +import math +from functools import reduce +from itertools import repeat +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Embedding as ScaledEmbedding + +# from valle.utils import Transpose + +class Transpose(nn.Identity): + """(N, T, D) -> (N, D, T)""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + if sign_factor is None: + ctx.save_for_backward(xgt0, scale_factor) + else: + ctx.save_for_backward(xgt0, scale_factor, sign_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + if len(ctx.saved_tensors) == 3: + xgt0, scale_factor, sign_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + sign_factor = sign_factor.unsqueeze(-1) + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + else: + xgt0, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) + + if min_abs == 0.0: + below_threshold = 0.0 + else: + # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if + # x_abs)_mean , min_abs. + below_threshold = ( + (min_abs - x_abs_mean) * (gain_factor / min_abs) + ).clamp(min=0, max=max_factor) + + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) + + return below_threshold - above_threshold + + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + if min_positive == 0.0: + factor1 = 0.0 + else: + # 0 if proportion_positive >= min_positive, else can be + # as large as max_factor. + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) + + if max_positive == 1.0: + factor2 = 0.0 + else: + # 0 if self.proportion_positive <= max_positive, else can be + # as large as -max_factor. + factor2 = ( + (proportion_positive - max_positive) + * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) + sign_factor = factor1 - factor2 + # require min_positive != 0 or max_positive != 1: + assert not isinstance(sign_factor, float) + return sign_factor + + +class ActivationScaleBalancerFunction(torch.autograd.Function): + """ + This object is used in class ActivationBalancer when the user specified + min_positive=0, max_positive=1, so there are no constraints on the signs + of the activations and only the absolute value has a constraint. + """ + + @staticmethod + def forward( + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + ctx.save_for_backward(xgt0, sign_factor, scale_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + xgt0, sign_factor, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + sign_factor = sign_factor.unsqueeze(-1) + scale_factor = scale_factor.unsqueeze(-1) + + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +class RandomClampFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: + x_clamped = torch.clamp(x, min=min, max=max) + mask = torch.rand_like(x) < prob + ans = torch.where(mask, x_clamped, x) + if x.requires_grad: + ctx.save_for_backward(ans == x) + ctx.reflect = reflect + if reflect != 0.0: + ans = ans * (1.0 + reflect) - (x * reflect) + return ans + + @staticmethod + def backward( + ctx, ans_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None]: + (is_same,) = ctx.saved_tensors + x_grad = ans_grad * is_same.to(ans_grad.dtype) + reflect = ctx.reflect + if reflect != 0.0: + x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) + return x_grad, None, None, None, None + + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): + return RandomClampFunction.apply(x, min, max, prob, reflect) + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class RandomGradFunction(torch.autograd.Function): + """ + Does nothing in forward pass; in backward pass, gets rid of very small grads using + randomized approach that preserves expectations (intended to reduce roundoff). + """ + + @staticmethod + def forward(ctx, x: Tensor, min_abs: float) -> Tensor: + ctx.min_abs = min_abs + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: + if ans_grad.dtype == torch.float16: + return ( + random_cast_to_half( + ans_grad.to(torch.float32), min_abs=ctx.min_abs + ), + None, + ) + else: + return ans_grad, None + + +class RandomGrad(torch.nn.Module): + """ + Gets rid of very small gradients using an expectation-preserving method, intended to increase + accuracy of training when using amp (automatic mixed precision) + """ + + def __init__(self, min_abs: float = 5.0e-06): + super(RandomGrad, self).__init__() + self.min_abs = min_abs + + def forward(self, x: Tensor): + if ( + torch.jit.is_scripting() + or not self.training + or torch.jit.is_tracing() + ): + return x + else: + return RandomGradFunction.apply(x, self.min_abs) + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_min: float + eps_max: float + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_min: float = -3.0, + eps_max: float = 3.0, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.eps_min = eps_min + self.eps_max = eps_max + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + eps = self.eps + if self.training and random.random() < 0.25: + # with probability 0.25, in training mode, clamp eps between the min + # and max; this will encourage it to learn parameters within the + # allowed range by making parameters that are outside the allowed + # range noisy. + + # gradients to allow the parameter to get back into the allowed region if it happens to exit it. + eps = eps.clamp(min=self.eps_min, max=self.eps_max) + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + ) ** -0.5 + return x * scales + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans + + +def ScaledConv1d( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans + + +def TransposeScaledConv1d( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Sequential: + """ + Transpose -> ScaledConv1d + """ + return nn.Sequential( + Transpose(), + ScaledConv1d( + *args, + initial_scale=initial_scale, + kernel_size=kernel_size, + padding=padding, + **kwargs, + ), + ) + + +def ScaledConv1dTranspose( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Sequential: + """ + Transpose -> ScaledConv1d + """ + return nn.Sequential( + ScaledConv1d( + *args, + initial_scale=initial_scale, + kernel_size=kernel_size, + padding=padding, + **kwargs, + ), + Transpose(), + ) + + +def TransposeConv1d( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + Transpose -> Conv1d + """ + return nn.Sequential( + Transpose(), + nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + ) + + +def Conv1dTranspose( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + ScaledConv1d -> Transpose + """ + return nn.Sequential( + nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + Transpose(), + ) + + +class SRLinear(nn.Linear): + """https://arxiv.org/abs/2303.06296 + Stabilizing Transformer Training by Preventing Attention Entropy Collapse + """ + + def __init__(self, in_features, out_features, bias=True, **kwargs): + super().__init__(in_features, out_features, bias=bias, **kwargs) + self.register_buffer( + "u", nn.functional.normalize(torch.randn(in_features), dim=0) + ) + with torch.no_grad(): + sigma = self.get_sigma() + self.register_buffer("spectral_norm", sigma) + self.sigma = nn.Parameter(torch.ones(1)) + + def get_sigma(self): + with torch.no_grad(): + u = self.u + v = self.weight.mv(u) + v = nn.functional.normalize(v, dim=0) + u = self.weight.T.mv(v) + u = nn.functional.normalize(u, dim=0) + self.u.data.copy_(u) + return torch.einsum("c,cd,d->", v, self.weight, u) + + def get_weight(self): + sigma = self.get_sigma() + if self.training: + self.spectral_norm.data.copy_(sigma) + weight = (self.sigma / sigma) * self.weight + return weight + + def forward(self, x): + return nn.functional.linear(x, self.get_weight(), self.bias) + + +class SRConv1d(SRLinear): + def __init__( + self, + in_features, + out_features, + kernel_size, + stride: int = 1, + padding: str = "same", + bias: bool = True, + **kwargs, + ): + in_features = in_features * kernel_size + super().__init__(in_features, out_features, bias=bias, **kwargs) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + def forward(self, x): + in_features = self.in_features // self.kernel_size + weight = self.get_weight().view( + self.out_features, in_features, self.kernel_size + ) + return nn.functional.conv1d( + x, weight, bias=self.bias, stride=self.stride, padding=self.padding + ) + + +def TransposeSRConv1d( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + Transpose -> SRConv1d + """ + return nn.Sequential( + Transpose(), + SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + ) + + +def SRConv1dTranspose( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + SRConv1d -> Transpose + """ + return nn.Sequential( + SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + Transpose(), + ) + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + min_prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): + super(ActivationBalancer, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor + + # count measures how many times the forward() function has been called. + # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. + self.cpu_count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or torch.jit.is_tracing() + ): + return _no_op(x) + + count = self.cpu_count + self.cpu_count += 1 + + if random.random() < 0.01: + # Occasionally sync self.cpu_count with self.count. + # count affects the decay of 'prob'. don't do this on every iter, + # because syncing with the GPU is slow. + self.cpu_count = max(self.cpu_count, self.count.item()) + self.count.fill_(self.cpu_count) + + # the prob of doing some work exponentially decreases from 0.5 till it hits + # a floor at min_prob (==0.1, by default) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + + if random.random() < prob: + sign_gain_factor = 0.5 + if self.min_positive != 0.0 or self.max_positive != 1.0: + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) + else: + sign_factor = None + + scale_factor = _compute_scale_factor( + x.detach(), + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) + return ActivationBalancerFunction.apply( + x, + scale_factor, + sign_factor, + self.channel_dim, + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar ** 2).sum() / ( + num_groups * channels_per_group + ) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float, + ) -> Tensor: + ctx.save_for_backward(x) + ctx.num_groups = num_groups + ctx.whitening_limit = whitening_limit + ctx.grad_scale = grad_scale + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, ctx.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) + + (metric - ctx.whitening_limit).relu().backward() + penalty_grad = x_detached.grad + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert whitening_limit >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + if isinstance(prob, float): + assert 0 < prob <= 1 + self.prob = prob + else: + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob < self.max_prob <= 1 + self.prob = self.max_prob + + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + if ( + not x.requires_grad + or random.random() > self.prob + or self.grad_scale == 0 + ): + return _no_op(x) + else: + if hasattr(self, "min_prob") and random.random() < 0.25: + # occasionally switch between min_prob and max_prob, based on whether + # we are above or below the threshold. + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): + # there would be a change to the grad. + self.prob = self.max_prob + else: + self.prob = self.min_prob + + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor): + ctx.y_shape = y.shape + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + +def with_loss(x, y): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y) + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class MaxEig(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to discourage + that any given direction in activation space accounts for more than + a specified proportion of the covariance (e.g. 0.2). + + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + max_var_per_eig: the maximum proportion of the variance of the + features/channels, after mean subtraction, that can come from + any given eigenvalue. + min_prob: the minimum probability with which we apply this during any invocation + of forward(), assuming last time we applied the constraint it was + not active; supplied for speed. + scale: determines the scale with which we modify the gradients, relative + to the existing / unmodified gradients + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, + ): + super(MaxEig, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.scale = scale + assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels + self.max_var_per_eig = max_var_per_eig + + # we figure out the dominant direction using the power method: starting with + # a random vector, keep multiplying by the covariance and renormalizing. + with torch.no_grad(): + # arbitrary.. would use randn() but want to leave the rest of the model's + # random parameters unchanged for comparison + direction = torch.arange(num_channels).to(torch.float) + direction = direction / direction.norm() + self.register_buffer("max_eig_direction", direction) + + self.min_prob = min_prob + # cur_prob is the current probability we'll use to apply the ActivationBalancer. + # We'll regress this towards prob, each tiem we try to apply it and it is not + # active. + self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + or torch.jit.is_tracing() + ): + return _no_op(x) + + with torch.cuda.amp.autocast(enabled=False): + eps = 1.0e-20 + orig_x = x + x = x.to(torch.float32) + with torch.no_grad(): + x = x.transpose(self.channel_dim, -1).reshape( + -1, self.num_channels + ) + x = x - x.mean(dim=0) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. + variance_proportion = (x_var - x_residual_var) / ( + x_var + 1.0e-20 + ) + + # ensure new direction is nonzero even if x == 0, by including `direction`. + self._set_direction( + 0.1 * self.max_eig_direction + new_direction + ) + + if random.random() < 0.01 or __name__ == "__main__": + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) + + if variance_proportion >= self.max_var_per_eig: + # The constraint is active. Note, we should quite rarely + # reach here, only near the beginning of training if we are + # starting to diverge, should this constraint be active. + cur_prob = self.cur_prob + self.cur_prob = ( + 1.0 # next time, do the update with probability 1.0. + ) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) + else: + # let self.cur_prob exponentially approach self.min_prob, as + # long as the constraint is inactive. + self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob + return orig_x + + def _set_direction(self, direction: Tensor): + """ + Sets self.max_eig_direction to a normalized version of `direction` + """ + direction = direction.detach() + direction = direction / direction.norm() + direction_sum = direction.sum().item() + if direction_sum - direction_sum == 0: # no inf/nan + self.max_eig_direction[:] = direction + else: + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) + + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ + (num_frames, num_channels) = x.shape + assert num_channels > 1 and num_frames > 1 + assert prev_direction.shape == (num_channels,) + # `coeffs` are the coefficients of `prev_direction` in x. + # actually represent the coeffs up to a constant positive factor. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 + cur_direction = (x * coeffs).sum(dim=0) / ( + (coeffs ** 2).sum() + 1.0e-20 + ) + return cur_direction, coeffs + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.043637 + ceil = 1.2 + d_scaled = (deriv - floor) * ( + 255.0 / (ceil - floor) + ) + torch.rand_like(deriv) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +def BalancedDoubleSwish( + d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25 +) -> nn.Sequential: + """ + ActivationBalancer -> DoubleSwish + """ + balancer = ActivationBalancer( + d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob + ) + return nn.Sequential( + balancer, + DoubleSwish(), + ) + + +def _test_max_eig(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad, atol=1.0e-02) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ( + (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + min_prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() + _test_whiten() + _test_max_eig() + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() \ No newline at end of file diff --git a/models/modules/transformer.py b/models/modules/transformer.py new file mode 100644 index 0000000..1859258 --- /dev/null +++ b/models/modules/transformer.py @@ -0,0 +1,698 @@ +# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024 +import copy +import numbers +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from .activation import MultiheadAttention +from .scaling import ActivationBalancer, BalancedDoubleSwish +from .scaling import BasicNorm as _BasicNorm + +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +class BasicNorm(_BasicNorm): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ): + super(BasicNorm, self).__init__(d_model, eps=eps) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + super(BasicNorm, self).forward(input), + embedding, + ) + + assert embedding is None + return super(BasicNorm, self).forward(input) + + +class BalancedBasicNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ): + super(BalancedBasicNorm, self).__init__() + self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return self.norm((self.balancer(input), embedding)) + + assert embedding is None + return self.norm(self.balancer(input)) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + elif activation == BalancedDoubleSwish: + activation = BalancedDoubleSwish(d_model) + + # # We can't test self.activation in forward() in TorchScript, + # # so stash some information about it instead. + # if activation is F.relu or isinstance(activation, torch.nn.ReLU): + # self.activation_relu_or_gelu = 1 + # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + # self.activation_relu_or_gelu = 2 + # else: + # self.activation_relu_or_gelu = 0 + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + need_weights: Optional[bool] = False, + past: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + if need_weights: + if self.norm_first: + out, attn = self._sa_block_attn( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + past + ) + out, present = out # present is the kvcache of the present timestep + x = x + out + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past) + out, present = out # present is the kvcache of the present timestep + x = self.norm1( + x + out, + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + assert not is_src_tuple + # return (x, stage_embedding) + return (x, attn) + else: + if self.norm_first: + out = self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, past + ) + out, present = out # present is the kvcache of the present timestep + x = x + out + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + out = self._sa_block(x, src_mask, src_key_padding_mask) + out, present = out # present is the kvcache of the present timestep + x = self.norm1( + x + out, + stage_embedding, past + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + x = (x, stage_embedding) + if present != None: + x = [x, present] + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + past: Optional[Tensor] = None, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + past=past + ) + x, present = x + return self.dropout1(x), present + + # self-attention block, also return attention weights + def _sa_block_attn( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + past: Optional[Tensor] = None, + ) -> Tensor: + x, attn = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=True, + past=past + ) + x, present = x + return (self.dropout1(x), present), attn + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + need_weights:Optional[bool] = False, + past: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + return_layer_states: return layers' state (optional). + + Shape: + see the docs in Transformer class. + """ + if return_layer_states: + assert not need_weights + layer_states = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + past=past + ) + layer_states.append(output[0]) + + if self.norm is not None: + output = self.norm(output) + + return layer_states, output + if need_weights: + assert not return_layer_states + layer_attn = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + need_weights=True, + past=past + ) + layer_attn.append(output[1]) + + if self.norm is not None: + output = self.norm(output) + + return layer_attn, output + + output = src + all_present = [] + for n_layer, mod in enumerate(self.layers): + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer] + ) + if isinstance(output, list): + output, present = output + all_present.append(present) + + if self.norm is not None: + output = self.norm(output) + if all_present != []: + all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim) + output = [output, all_present] + return output + + +class TransformerDecoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + self.activation = activation(d_model) + elif activation == BalancedDoubleSwish: + self.activation = BalancedDoubleSwish(d_model) + else: + self.activation = activation + + if adaptive_layer_norm: + norm1 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + self.norm3 = AdaptiveLayerNorm(d_model, norm3) + else: + self.norm1 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + self.norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + if layer_norm_cls == IdentityNorm: + self.norm3 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + self.norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + tgt_is_tuple = False + if isinstance(tgt, tuple): + x, stage_embedding = tgt + tgt_is_tuple = True + else: + x, stage_embedding = tgt, None + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask + ) + x = x + self._mha_block( + self.norm2(x, stage_embedding), + memory, + memory_mask, + memory_key_padding_mask, + ) + x = x + self._ff_block(self.norm3(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask), + stage_embedding, + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask + ), + stage_embedding, + ) + x = self.norm3(x + self._ff_block(x), stage_embedding) + + if tgt_is_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout2(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) \ No newline at end of file diff --git a/models/modules/utils.py b/models/modules/utils.py new file mode 100644 index 0000000..8a46980 --- /dev/null +++ b/models/modules/utils.py @@ -0,0 +1,37 @@ +# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng +import torch + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + + return expaned_lengths >= lengths.unsqueeze(-1) + +def generate_partial_autoregressive_mask(sz, start, end): + mask = torch.zeros(sz, sz).bool() + mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1) + mask[:start, start:end] = True + mask[end:, start:end] = True + return mask diff --git a/models/voicecraft.py b/models/voicecraft.py new file mode 100644 index 0000000..4042cae --- /dev/null +++ b/models/voicecraft.py @@ -0,0 +1,1402 @@ +import random + +import numpy as np +import logging +import argparse, copy +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchmetrics.classification import MulticlassAccuracy + +from .modules.utils import make_pad_mask + +from .modules.embedding import SinePositionalEmbedding, TokenEmbedding +from .modules.transformer import ( + LayerNorm, + TransformerEncoder, + TransformerEncoderLayer, +) +from .codebooks_patterns import DelayedPatternProvider + +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min( + max(top_k, min_tokens_to_keep), logits.size(-1) + ) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): + # temperature: (`optional`) float + # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + # top_k: (`optional`) int + # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + # top_p: (`optional`) float + # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + # Sample + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return token + + + +class VoiceCraft(nn.Module): + def __init__(self, args): + super().__init__() + self.args = copy.copy(args) + self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks) + if not getattr(self.args, "special_first", False): + self.args.special_first = 0 + if not getattr(self.args, "n_special", False): + self.args.n_special = 3 + self.args.eos = getattr(self.args, "eos", -1) + self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1] + if self.args.eos > 0: + assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos + self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1] + if type(self.args.audio_vocab_size) == str: + self.args.audio_vocab_size = eval(self.args.audio_vocab_size) + + self.n_text_tokens = self.args.text_vocab_size + 1 + assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}" + + self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token + assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token + assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog + assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token + + self.text_embedding = TokenEmbedding( + dim_model=self.args.d_model, + vocab_size=self.n_text_tokens, + dropout=self.args.text_embedding_dropout + ) + + self.audio_embedding = nn.ModuleList( + [ + TokenEmbedding( + dim_model=self.args.audio_embedding_dim, + vocab_size=self.n_audio_tokens[k], + dropout=self.args.audio_embedding_dropout + ) for k in range(self.args.n_codebooks) + ] + ) + self.mask_embedding = nn.Parameter(torch.randn(self.args.max_n_spans, self.args.d_model), requires_grad=True) + self.text_positional_embedding = SinePositionalEmbedding( + self.args.d_model, + dropout=self.args.text_positional_embedding_dropout, + scale=False, + alpha=True, # learnable scaler, scale the volume of positional embedding + ) + self.audio_positional_embedding = SinePositionalEmbedding( + self.args.d_model, + dropout=self.args.audio_positional_embedding_dropout, + scale=False, + alpha=True, # learnable scaler, scale the volume of positional embedding + ) + + dec_layer = TransformerEncoderLayer( + self.args.d_model, + self.args.nhead, + dim_feedforward=self.args.d_model * 4, + dropout=self.args.trm_dropout, + batch_first=True, + norm_first=True, + layer_norm_cls=LayerNorm + ) + self.decoder = TransformerEncoder( + dec_layer, + num_layers=self.args.num_decoder_layers, + norm=LayerNorm(self.args.d_model), + ) + + self.predict_layer = nn.ModuleList( + [ + nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks) + ] + ) + + self.accuracy_metrics = nn.ModuleList( + [MulticlassAccuracy( + self.n_audio_tokens[k], + top_k=10, + average="micro", + multidim_average="global", + ignore_index=None, + ) for k in range(self.args.n_codebooks)] + ) + + + def prepare_mask_intervals(self, y_lens): + mask_intervals = [] + non_mask_intervals = [] + + for i, y_len in enumerate(y_lens): + if self.args.mask_sample_dist == "uniform": + n_spans = random.choice(range(1, self.args.max_n_spans+1)) + elif "poisson" in self.args.mask_sample_dist.lower(): + param = float(self.args.mask_sample_dist[len("poisson"):]) + poisson_sample = torch.poisson(torch.tensor([param])) + n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item()) + + starts = random.sample(range(1, y_len-1-self.args.mask_len_min), n_spans) + starts = sorted(starts) + + for j in range(len(starts)-1, 0, -1): + if starts[j] - starts[j-1] < self.args.min_gap: + del starts[j] # If elements are too close, delete the later one + assert len(starts) > 0, f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}" + + temp_starts = starts + [y_len] + gaps = [temp_starts[j+1] - temp_starts[j] for j in range(len(temp_starts)-1)] + + ends = [] + + for j, (start, gap) in enumerate(zip(starts, gaps)): + mask_len = random.randint(self.args.mask_len_min, self.args.mask_len_max) + # if mask_len > gap * self.args.max_mask_portion: # make sure the masks are not overlapping with each other + if mask_len > gap - 1: # make sure the masks are not overlapping with each other + # temp_mask_start = int(0.6*gap*self.args.max_mask_portion) + # temp_mask_end = int(gap*self.args.max_mask_portion) + temp_mask_start = 1 + temp_mask_end = gap - 1 + mask_len = random.randint(temp_mask_start, temp_mask_end) + ends.append(start + mask_len) + + mask_intervals.append([(s,e) for s,e in zip(starts, ends)]) + non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])]) + + return mask_intervals, non_mask_intervals + + def rearrange(self, y, non_mask_intervals, mask_intervals): + reduced_eog = getattr(self.args, "reduced_eog", 0) + rearranged_y = [] + for i in range(len(y)): + if self.args.eos > 0: + assert reduced_eog + cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eos], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends + else: + if reduced_eog: + cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eog], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends + else: + cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment) + rearranged_y.append(cur_y) + return rearranged_y + + def shift(self, rearranged_y): + shifted_y = [] + patterns = [] + for i in range(len(rearranged_y)): + cur_patterns = [self.pattern.get_pattern(cur_y.shape[1]) for cur_y in rearranged_y[i]] + out = [cur_pattern.build_pattern_sequence(z=cur_y.unsqueeze(0).contiguous(), special_token=self.args.empty_token, keep_only_valid_steps=False) for cur_pattern, cur_y in zip(cur_patterns, rearranged_y[i])] + shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask + patterns.append(cur_patterns) + return shifted_y, patterns + + def insert_mask(self, shifted_y): + inserted_y = [] + mask_position = [] + mask_value = [] + for i in range(len(shifted_y)): + num_masks = (len(shifted_y[i]) - 1) // 2 + assert num_masks == (len(shifted_y[i]) - 1) / 2, len(shifted_y[i]) + emb_inds = list(range(self.args.max_n_spans)) + if self.args.shuffle_mask_embedding: + random.shuffle(emb_inds) + emb_inds_use = emb_inds[:num_masks] + emb_inds_use = emb_inds_use + emb_inds_use + mask_value.append(emb_inds_use) + cur_inserted_y = [] + cur_mask_position = [] + for j in range(len(shifted_y[i])-1): + cur_inserted_y.append(shifted_y[i][j]) + cur_mask_position.append(sum([item.shape[1] for item in cur_inserted_y])) # each item is of shape [K S], so take shape[1] + cur_inserted_y.append(self.eog) # insert mask token of shape [K, 1], BUT we are actually using the eog token as a place holder here, as the real mask will be inserted in embed_y function + + cur_inserted_y.append(shifted_y[i][-1]) + + inserted_y.append(cur_inserted_y) + mask_position.append(cur_mask_position) + return inserted_y, mask_position, mask_value + + def cat_y(self, inserted_y, mask_position, y_lens): + reduced_eog = getattr(self.args, "reduced_eog", 0) + cated_y = [] + new_y_lens = [] + for i in range(len(inserted_y)): + cur_cated_y = torch.cat(inserted_y[i], dim=1) #[K S] + cur_cated_y = cur_cated_y.transpose(1,0) # [S K] + cur_cated_y_len = cur_cated_y.shape[0] + if reduced_eog: + assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i])/2 + 1) ({len(mask_position[i])/2 + 1})={y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1)}" + else: + assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i]) + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i]) + 1) ({len(mask_position[i]) + 1})" # the last term represent the inserted eog token, originally it's inserted at the end of every token, but this is wrong + new_y_lens.append(cur_cated_y_len) + cated_y.append(cur_cated_y) + + cated_y = torch.nn.utils.rnn.pad_sequence(cated_y, batch_first=False, padding_value=self.args.audio_pad_token) + assert cated_y.shape == torch.Size([max(new_y_lens),len(inserted_y), self.args.n_codebooks]), f"cated_y.shape: {cated_y.shape}, but it should be {torch.Size([max(new_y_lens,len(inserted_y), self.args.n_codebooks)])}" + cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B] + assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape + return cated_y, torch.LongTensor(new_y_lens).to(cated_y.device) + + def embed_y(self, cated_y, mask_position, mask_value): + embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D] + assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape + assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape + embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D] + embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D] + for i in range(len(embedded_y)): + if len(mask_position[i]) > 0: + embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]] + return embedded_y + + def prepare_input_target(self, y, y_lens): + # rearrange y + # assume y shape: [B T K], K is n_codebooks + assert y.shape[1] == self.args.n_codebooks, y.shape + # sample mask_intervals + mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_lens) + + # need to have EOG in each section (SOG will be generated by the pattern class) + # but mask can be inserted later after we have shifted the input + # y could be rearranged in this way: + # [ + # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]], + # [tensor[4, 44], tensor[4, 56], tensor[4, 19]], + # ... + # ] + # for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part. + # NOTE #non_masked_part = #masked_part + 1 + # NOTE *these are also the targets* + # added eog at the end of each segment (masked segment and unmasked segment) + rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals) + targets = rearranged_y # each element in each sample is of shape [K T] + assert targets[0][0].shape[0] == self.args.n_codebooks, targets[0][0].shape + + # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token + # [[5, 1, 2, 3, 4, 5, 5], + # [5, 5, 1, 2, 3, 4, 5], + # [5, 5, 5, 1, 2, 3, 4]] + shifted_y, patterns = self.shift(rearranged_y) # each element [K S] + assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape[0] + + + # then, insert mask token at the intersection of each tensor (we want to decide the arrangement of the mask (shuffle or not)), we better have a separate nn.embedding for it + # we also need to record the position of the inserted mask + inserted_y, mask_position, mask_value = self.insert_mask(shifted_y) + assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0] + assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}" + + # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token + cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB + assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y))) + + + # embed remember to separately embed the mask tokens + embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD + assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape + + # positional embedding + y_input = self.audio_positional_embedding(embedded_y) + + # make attention mask and padding mask + y_padding_mask = make_pad_mask(new_y_lens).to(y.device) + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device) + return y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns + + def remove_mask(self, logits, mask_position, new_y_lens): + # logits: [B K S card] + logits_use = [] + for i in range(len(logits)): + non_mask_positions = [-1] + mask_position[i] + [new_y_lens[i]] + non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)] + cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals] + logits_use.append(cur_logits_use) + + return logits_use + + def revert_pattern(self, patterns, logits_use): + logits_final = [] + logit_masks = [] + for i in range(len(logits_use)): + cur_logits = [ + item.unsqueeze(0).permute(0, 3, 1, 2).contiguous() for item in logits_use[i] + ] # each item is of shape [1 K S card] [1 card K S] + cur_logits_final = [ + cur_pattern.revert_pattern_logits( + item, 0, keep_only_valid_steps=False + ) + for cur_pattern, item in zip(patterns[i], cur_logits) + ] # if input output order doesn't match, this step will give an error + cur_logits_final_ret = [item[0].permute(0,2,3,1).squeeze(0) for item in cur_logits_final] # each element is of shape [K,T,card] + logits_final.append(cur_logits_final_ret) + logit_masks.append([item[2] for item in cur_logits_final]) + + return logits_final, logit_masks + + def dec_forward( + self, + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + past=None, + last_3_tokens=False + ): + x_attn_mask = F.pad( + x_attention_mask, + (0, new_y_lens.max()), + value=True, + ) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper + y_attn_mask = F.pad( + y_attention_mask, + (x_lens.max(), 0), # y is padded at the front + value=False, + ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + + # merge key padding and attention masks + bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max() + xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1) + _xy_padding_mask = ( + xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.args.nhead, -1, -1) + .reshape(bsz * self.args.nhead, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + + new_attn_mask = torch.zeros_like(xy_attn_mask) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + + xy_input = torch.cat([x_input, y_input], dim=1) + + if past == None: # do not use kvcache + out, _ = self.decoder((xy_input, None), mask=xy_attn_mask) + return out[:, x_lens.max():], None + else: # use kvcache + if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet + if last_3_tokens: + xy_input = xy_input[:, -3:] + xy_attn_mask = xy_attn_mask[:, -3:] + else: + xy_input = xy_input[:, -1:] + xy_attn_mask = xy_attn_mask[:, -1:] + + out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past) + if isinstance(out, tuple): # get rid of stage_embedding + out = out[0] + + if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet + return out[:, x_lens.max():], present + else: # used kvcache + return out, present + + def forward(self, batch): + """ + Args: + x: + A 2-D tensor of shape (N, S). + x_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (N, K, T). + where K is the number of codebooks + y_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + """ + x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"] + x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x + y = y[:, :y_lens.max()] + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape + assert y_lens.ndim == 1, y_lens.shape + # makes attention mask and padding mask for x + x_padding_mask = make_pad_mask(x_lens).to(x.device) + x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device) + x_input = self.text_embedding(x) + x_input = self.text_positional_embedding(x_input) + y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens) + y_out = self.dec_forward( + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask + ) + y_out = y_out[0] # no kv-caching during training + assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D] + + logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card] + # take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern) + assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape + + logits_use = self.remove_mask(logits, mask_position, new_y_lens) + + # revert the pattern shift for each logits section in each sample + logits_final, logit_masks = self.revert_pattern(patterns, logits_use) + assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]" + # testing + sample_to_test = 0 + assert len(logits_final[sample_to_test]) == len(targets[sample_to_test]), f"{len(logits_final[sample_to_test])}, {len(targets[sample_to_test])}" + temp = sum([logits_final[sample_to_test][i].shape[:-1] != targets[sample_to_test][i].shape for i in range(len(targets[sample_to_test]))]) + assert temp == 0, f"none equal positions: {temp}, total number of elements: {len(targets[sample_to_test])}" + + logit_masked = sum([(item==False).any() for cur_mask in logit_masks for item in cur_mask]) + assert logit_masked == 0, logit_masks + + logits = torch.cat([torch.cat(item, dim=1) for item in logits_final], dim=1) # [K, T1+T2+T3+..., card] + targets = torch.cat([torch.cat(item, dim=1) for item in targets], dim=1) # [K, T1+T2+T3+...] + assert targets.shape[0] == logits.shape[0], f"{targets.shape}, {logits.shape}" + loss = [] + ntokens = [] + top10acc = [] + for k, (logit, target) in enumerate(zip(logits, targets)): + loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None)) + top10acc.append(self.accuracy_metrics[k](logit.detach(), target)) + ntokens.append(len(logit)) + + all_ntokens = sum(ntokens) + if self.args.codebook_weight != None: + codebook_weight = eval(self.args.codebook_weight) + else: + codebook_weight = [1.] * self.args.n_codebooks + loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)]) + top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)] + top10acc = sum(top10acc_by_codebook) + ntokens = torch.tensor(all_ntokens).to(logits.device) + + return { + "loss": loss, + "top10acc": top10acc, + "top10acc_by_codebook": top10acc_by_codebook, + "effective_ntoken": ntokens, + } + + def inference( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + mask_interval: list[torch.Tensor], + top_k: int=-100, + top_p: float=1.0, + temperature: float=1.0, + stop_repetition: int=-1, + kvcache: int=1, + silence_tokens: list[int]=[1388,1898,131], + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (1, L). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, K). + mask_interval: + a list of tensors of shape (M, 2). contains M mask_start and mask_end. list length is actually 1, because we only support single sample inference for now + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. + top_p: (`optional`) float + For Neucleus sampling + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + eog_coef: (`optional`) float + if 0, no change to eog token logits, otherwise, will adjust eog token logit based on the difference between acoustic token and phn token length + stop_repetition (`optional`) int + if not -1, will set the logits of a token that repeated this many times to be -100000, to avoid generating it again. This only apply to tokens from the first codebook + allowed_repeat_tokens (`optional`) list of ints + by inspecting the validation set, get a few tokens that indeed repeat a significant amount of time, and exclude those tokens from prevent repetition + ultimate_stop_repetition (`optional`) int + no matter that token it is, stop repetition once after this number + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + if self.args.special_first: + y = y + int(self.args.n_special) + y = y.transpose(2,1) # [1,T,K] -> [1,K,T] + assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding + assert mask_interval.shape == torch.Size((1, mask_interval.shape[1], 2)), mask_interval + + # make x attention mask and x_input + x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device) + # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device) + x_input = self.text_embedding(x) + x_input = self.text_positional_embedding(x_input) + + # make initial y_input + + # make mask_interval and non_mask_interval + y_len = y.shape[2] + y_lens = torch.LongTensor([y_len]).to(y.device) + mask_interval = mask_interval[0] + starts = [item[0].item() for item in mask_interval] + [y_len] + ends = [0] + [item[1].item() for item in mask_interval] + mask_intervals = [[ + (item[0].item(), item[1].item()) for item in mask_interval + ]] # a werid name change, mask_interval is input, now is mask_intervals, with one more dimension + non_mask_intervals = [[ + (ns, ne) for ns, ne in zip(ends, starts) + ]] + + # rearrange y + # will add have EOG in each section (SOG will be generated by the pattern class) + # but mask can be inserted later after we have shifted the input + # y could be rearranged in this way: + # [ + # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]], + # [tensor[4, 44], tensor[4, 56], tensor[4, 19]], + # ... + # ] + # for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part. + # NOTE #non_masked_part = #masked_part + 1 + rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals) + assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape + + # shift each element of y + # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token + # [ + # [empty, 1, 2, 3, eog, empty, empty, empty], + # [empty, empty, 1, 2, 3, eog, empty, empty], + # [empty, empty, empty, 1, 2, 3, eog, empty], + # [empty, empty, empty, empty, 1, 2, 3, eog] + # ] + shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y + assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape + + # insert mask token at the intersction of each tensor, but *actually inserted eog as place holder* + # the position of inserted mask is also recorded + # and the mask_value, the index of the mask emb is recorded + inserted_y, mask_position, mask_value = self.insert_mask(shifted_y) + assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0] + assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}" + + # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token + cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB + assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y))) + assert not (cated_y == self.args.audio_pad_token).any(), cated_y + + ### NOTE this is different from forward, as we will remove the masked tokens + ### say there are two masked region + ### the cated_y should be like + ### [empty a a a a mask0 empty b b b mask1 empty c c mask0 empty] + ### which means we need to take the part after the last empty out + num_mask = len(mask_position[0])//2 + assert num_mask == len(mask_position[0])/2, mask_position + cated_y = cated_y[:, :mask_position[0][num_mask]+2] # of shape [K,T,B] + # logging.info(f"mask_position[0][num_mask]+2: {mask_position[0][num_mask]+2}") + more_mask_value = mask_value[0][num_mask+1:] # NOTE this will be used in the generation loop for reference for inserting mask embedding + new_y_lens[0] = mask_position[0][num_mask]+2 + mask_position[0] = mask_position[0][:num_mask+1] + assert mask_position[0][num_mask]+2 == cated_y.shape[1], f"num_mask: {num_mask}, mask_position: {mask_position}, cated_y.shape: {cated_y.shape}" + + # embed: remember to separately embed the mask tokens + embedded_y = self.embed_y(cated_y, mask_position, [mask_value[0][:num_mask+1]]) #BTD + # assert embedded_y.shape == torch.Size((y.shape[0], max(new_y_lens), self.args.d_model)), embedded_y.shape + + # positional embedding + y_input = self.audio_positional_embedding(embedded_y) + + # make attention mask and padding mask + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) + # y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) + + x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device) + y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) + + + codebook_eog = [False] * self.args.n_codebooks + generated = [] # doesn't contain any empty_token, contains eog + cur_generated = [] + # say 0 is empty, 4 is eog + # tensor([[ 1, 2, 3, 4, 0, 0], + # [ 0, 1, 2, 3, 4, 0], + # [ 0, 0, 1, 2, 3, 4]]) + num_gen = [] + cur_num_gen = 0 + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") + consec_silence_count = 0 + prev_token = None + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + # prepare the cache placeholder + # n_layers, 2, bsz, num_heads, src_len, head_dim + past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None + # handle multi-span kv-cache + new_masked_span = False + + def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen): + if n_eog == 0: + logits_adjust = logits + for jj in range(1,self.args.n_codebooks): + logits_adjust[jj][self.args.eog] = -10000 + logits_adjust[jj][self.args.empty_token] = -10000 + ##################### silence repetition handling ##################### + if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition: + if logits_adjust[0, prev_token] < 0: + logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1)) + else: + logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1)) + ##################### silence repetition handling ##################### + if type(logits_adjust) == list: + samples_list= [] + for logit in logits_adjust: + # print(logit) + # print(logit.shape) + cur_sample = topk_sampling( + logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature + ) # [1, 1] + samples_list.append(cur_sample) + samples = torch.cat(samples_list, dim=0) # [K, 1] + else: + samples = topk_sampling( + logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature + ) # [K, 1] + assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}" + if cur_num_gen < self.args.n_codebooks-1: + for jj in range(1, self.args.n_codebooks - cur_num_gen): + samples[-jj, 0] = self.args.empty_token + + if ( + samples[0,0] == self.args.eog or torch.argmax(logits[0], dim=-1) == self.args.eog or y_input.shape[1] > x_lens[0] * 10 + ): # last one means y is already too long, shouldn't happen, but put it here + samples[0,0] = self.args.eog + codebook_eog[0] = True + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + if samples[0,0] in silence_tokens and samples[0,0] == prev_token: + consec_silence_count += 1 + else: + consec_silence_count = 0 + prev_token = samples[0,0] + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + return samples, codebook_eog, prev_token, consec_silence_count + else: + assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}" + logits_adjust = logits + for jj in range(n_eog+1,self.args.n_codebooks): + logits_adjust[jj][self.args.eog] = -10000 + logits_adjust[jj][self.args.empty_token] = -10000 + if type(logits_adjust) == list: + samples_list= [] + for logit in logits_adjust: + cur_sample = topk_sampling( + logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature + ) # [1, 1] + samples_list.append(cur_sample) + samples = torch.cat(samples_list, dim=0) # [K, 1] + else: + samples = topk_sampling( + logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature + ) # [K, 1] + for jj in range(n_eog): + samples[jj, 0] = self.args.empty_token + samples[n_eog, 0] = self.args.eog + codebook_eog[n_eog] = True + return samples, codebook_eog, prev_token, consec_silence_count + + while True: + y_out, present = self.dec_forward( + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + past=past, + last_3_tokens = new_masked_span + ) + if new_masked_span: + new_masked_span = False + + if past != None: + past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype) + + y_out = y_out[:, -1:] # only take the last one + + logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card] + logits = logits.squeeze(0).squeeze(1) # [K card] + assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}" + + n_eog = sum(codebook_eog) + assert n_eog < self.args.n_codebooks + if self.args.eos > 0: # eos stands for end-of-sentence, which shouldn't be used as we are doing speech editing + for jj in range(self.args.n_codebooks): + logits[jj][self.args.eos] = -10000. + # need to use a helper function to hand different n_eog cases + samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen) + cur_num_gen += 1 + cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K] + # get samples_emb + samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D] + samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D] + + if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done + # re-init + codebook_eog = [False] * self.args.n_codebooks + num_gen.append(cur_num_gen) + cur_num_gen = 0 + generated.append(cur_generated) + cur_generated = [] + + # if the current mask span is the last span, then all done + # else + # append the next mask token and the four empty tokens to start the next generation + if len(more_mask_value) > 0: + next_mask_ind = more_mask_value.pop(0) + mask_emb = self.mask_embedding[next_mask_ind].unsqueeze(0).unsqueeze(0) # [1,1,D] + assert mask_emb.shape == torch.Size((1,1,self.args.d_model)), mask_emb.shape + empty_token = torch.LongTensor([self.args.empty_token]).to(y.device) + empty_emb = torch.stack([ + self.audio_embedding[k](empty_token) for k in range(self.args.n_codebooks)], dim=0 + ).sum(dim=0, keepdim=True) # [1,1,D] + assert empty_emb.shape == torch.Size((1,1,self.args.d_model)), empty_emb.shape + extra_emb = torch.cat([mask_emb, empty_emb], dim=1) # [1,2,D] + samples_emb = torch.cat([samples_emb, extra_emb], dim=1) # [1,3,D] # prev_last_token, mask_token, empty token + assert samples_emb.shape == torch.Size((1,3,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + consec_silence_count = 0 + prev_token = None + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + + # handling kv-caching for multi-span editing + new_masked_span = True + else: + break + else: + assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" + + embedded_y = torch.cat([embedded_y, samples_emb], dim=1) + # positional embedding + y_input = self.audio_positional_embedding(embedded_y) # [B T D] + # make attention mask and padding mask + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) + new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) + y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) + + assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}" + + # # combine non_masked_span with generated spans + # first need to shift the generated part back + flatten_gen = [] + for l, orig_span in enumerate(generated): + span = torch.stack(orig_span, dim=0) # [T K] + span = span.transpose(1,0) # [K, T] + assert span.shape[0] == self.args.n_codebooks, span.shape + unshifted_span = [] + for j, s in enumerate(span): + start_from = j + end_at = - (self.args.n_codebooks - start_from) + unshifted_span.append(s[start_from:end_at]) + unshifted_span = torch.stack(unshifted_span, dim=0) + + assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}" + flatten_gen.append(unshifted_span) + # logging.info(f"unshfited_span: {unshifted_span.shape}") + # raise + assert len(non_mask_intervals[0]) - 1 == len(flatten_gen), f"len(non_mask_intervals[0]): {len(non_mask_intervals[0])}, len(flatten_gen): {len(flatten_gen)}" + res = [] + for orig_interval, gen in zip(non_mask_intervals[0], flatten_gen): + res.append(y[0, :, orig_interval[0]:orig_interval[1]]) + res.append(gen) + res.append(y[0, :, non_mask_intervals[0][-1][0]:non_mask_intervals[0][-1][1]]) + res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T] + + expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen]) + assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}" + + if self.args.special_first: + res = res - int(self.args.n_special) + + return res + + def inference_tts( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + top_k: int=-100, + top_p: float=1.0, + temperature: float=1.0, + stop_repetition: int=3, + kvcache: int=1, + silence_tokens: list[int]=[1388,1898,131], + *kargs + ) -> torch.Tensor: + """ + different from inference_tts, this implementation uses kvcache, which should have significant speed up + Args: + x: + A 2-D tensor of shape (1, L). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, K). + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. + top_p: (`optional`) float + For Neucleus sampling + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + """ + eog_inference = self.args.eos if self.args.eos>0 else self.args.eog + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + if self.args.special_first: + y = y + int(self.args.n_special) + y = y.transpose(2,1) # [1,T,K] -> [1,K,T] + assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding + + # make x attention mask and x_input + x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device) + # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device) + x_input = self.text_embedding(x) + x_input = self.text_positional_embedding(x_input) + + y_len = y.shape[2] + y_lens = torch.LongTensor([y_len]).to(y.device) + + # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario + rearranged_y = [[y[0]]] + assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape + + # shift y to create the delayed pattern + shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y + assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape + assert len(shifted_y[0]) == 1, len(shifted_y[0]) + + # below is different from forward or inference + # where we cut this shifted part + shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)] + assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0] + + # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that + # next section is concate tensors of each sample to one tensor, which we also don't need + cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B] + new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device) + assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1)) + assert not (cated_y == self.args.audio_pad_token).any(), cated_y + + # replace tokens in y with the embeddings, add sum codebooks up + embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D] + assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape + assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape + embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D] + embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D] + + # positional embedding + y_input = self.audio_positional_embedding(embedded_y) + + # make attention mask and padding mask + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) + + x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device) + y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) + + # entering the generation stage + # starting from line 708 + codebook_eog = [False] * self.args.n_codebooks + generated = [] # doesn't contain any empty token, contain eog + cur_generated = [] + # say 0 is empty, 4 is eog + # tensor([[ 1, 2, 3, 4, 0, 0], + # [ 0, 1, 2, 3, 4, 0], + # [ 0, 0, 1, 2, 3, 4]]) + num_gen = [] + cur_num_gen = 0 + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") + consec_silence_count = 0 + prev_token = None + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + + # prepare the cache placeholder + # n_layers, 2, bsz, num_heads, src_len, head_dim + past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None + # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") + # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") + # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") + def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen): + if n_eog == 0: + logits_adjust = logits + for jj in range(1,self.args.n_codebooks): + logits_adjust[jj][eog_inference] = -10000 + logits_adjust[jj][self.args.empty_token] = -10000 + ##################### silence repetition handling ##################### + if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition: + if logits_adjust[0, prev_token] < 0: + logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1)) + else: + logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1)) + ##################### silence repetition handling ##################### + samples = topk_sampling( + logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature + ) # [K, 1] + assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}" + if cur_num_gen < self.args.n_codebooks-1: + for jj in range(1, self.args.n_codebooks - cur_num_gen): + samples[-jj, 0] = self.args.empty_token + + if ( + samples[0,0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr//5) + ): # last one means y is already too long, shouldn't happen, but put it here + samples[0,0] = eog_inference + codebook_eog[0] = True + ##################### silence repetition handling ##################### + if samples[0,0] in silence_tokens and samples[0,0] == prev_token: + consec_silence_count += 1 + else: + consec_silence_count = 0 + prev_token = samples[0,0] + ##################### silence repetition handling ##################### + return samples, codebook_eog, prev_token, consec_silence_count + else: + assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}" + logits_adjust = logits + for jj in range(n_eog+1,self.args.n_codebooks): + logits_adjust[jj][eog_inference] = -10000 + logits_adjust[jj][self.args.empty_token] = -10000 + samples = topk_sampling( + logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature + ) # [K, 1] + for jj in range(n_eog): + samples[jj, 0] = self.args.empty_token + samples[n_eog, 0] = eog_inference + codebook_eog[n_eog] = True + return samples, codebook_eog, prev_token, consec_silence_count + while True: + y_out, present = self.dec_forward( + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + past=past + ) + if past != None: + past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype) + + + y_out = y_out[:, -1:] # only take the last token + logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card] + logits = logits.squeeze(0).squeeze(1) # [K card] + assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}" + + n_eog = sum(codebook_eog) + assert n_eog < self.args.n_codebooks + if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans + for jj in range(self.args.n_codebooks): + logits[jj][self.args.eog] = -10000. + + samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen) + + cur_num_gen += 1 + cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K] + + # samples.shape is [K,1] + # ge samples_emb + samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D] + samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D] + + if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done + codebook_eog = [False] * self.args.n_codebooks + num_gen.append(cur_num_gen) + cur_num_gen = 0 + generated.append(cur_generated) + cur_generated = [] + break + else: + assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" + + embedded_y = torch.cat([embedded_y, samples_emb], dim=1) + y_input = self.audio_positional_embedding(embedded_y) # [B T D] + # make attention mask and padding mask + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) + new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device) + y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) + + assert len(generated) == 1, f"len(generated): {len(generated)}" + + # revert the pattern + flatten_gen = [] + for l, orig_span in enumerate(generated): + span = torch.stack(orig_span, dim=0) # [T, K] + span = span.transpose(1,0) # [K, T] + assert span.shape[0] == self.args.n_codebooks, span.shape + unshifted_span = [] + for j, s in enumerate(span): + start_from = j + end_at = - (self.args.n_codebooks - start_from) + unshifted_span.append(s[start_from:end_at]) + unshifted_span = torch.stack(unshifted_span, dim=0) + + assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}" + + flatten_gen.append(unshifted_span) + assert len(flatten_gen) == 1, len(flatten_gen) + + # combine + res = [y[0], flatten_gen[0]] + res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T] + + expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen]) + assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}" + + if self.args.special_first: + res = res - int(self.args.n_special) + flatten_gen = flatten_gen - int(self.args.n_special) + + return res, flatten_gen[0].unsqueeze(0) + + + def inference_tts_batch( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + top_k: int=-100, + top_p: float=1.0, + temperature: float=1.0, + stop_repetition: int=3, + kvcache: int=1, + batch_size: int=5, + silence_tokens: list[int]=[1388,1898,131], + *kargs + ) -> torch.Tensor: + """ + have a batch size when forward passing, but they are equivalant to same example but different random seed, therefore as long as one example generated eog, we can drop all other samlpes + different from inference_tts, this implementation uses kvcache, which should have significant speed up + Args: + x: + A 2-D tensor of shape (1, L). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, K). + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. + top_p: (`optional`) float + For Neucleus sampling + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + """ + eog_inference = self.args.eos if self.args.eos>0 else self.args.eog + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + if self.args.special_first: + y = y + int(self.args.n_special) + y = y.transpose(2,1) # [1,T,K] -> [1,K,T] + assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding + + # make x attention mask and x_input + x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device) + # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device) + x_input = self.text_embedding(x) + x_input = self.text_positional_embedding(x_input) + + y_len = y.shape[2] + y_lens = torch.LongTensor([y_len]).to(y.device) + + # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario + rearranged_y = [[y[0]]] + assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape + + # shift y to create the delayed pattern + shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y + assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape + assert len(shifted_y[0]) == 1, len(shifted_y[0]) + + # below is different from forward or inference + # where we cut this shifted part + shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)] + assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0] + + # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that + # next section is concate tensors of each sample to one tensor, which we also don't need + cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B] + new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device) + assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1)) + assert not (cated_y == self.args.audio_pad_token).any(), cated_y + + # replace tokens in y with the embeddings, add sum codebooks up + embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D] + assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape + assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape + embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D] + embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D] + + # positional embedding + y_input = self.audio_positional_embedding(embedded_y) + + # make attention mask and padding mask + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) + + x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device) + y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) + + # entering the generation stage + # starting from line 708 + codebook_eog = [False] * self.args.n_codebooks + generated = [] # doesn't contain any empty token, contain eog + cur_generated = [[] for _ in range(batch_size)] + # say 0 is empty, 4 is eog + # tensor([[ 1, 2, 3, 4, 0, 0], + # [ 0, 1, 2, 3, 4, 0], + # [ 0, 0, 1, 2, 3, 4]]) + num_gen = [] + cur_num_gen = 0 + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") + consec_silence_counts = [0 for _ in range(batch_size)] + prev_tokens = [None for _ in range(batch_size)] + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + + # prepare the cache placeholder + # n_layers, 2, bsz, num_heads, src_len, head_dim + past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None + # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") + # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") + # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}") + keep = None # NOTE: this very important, tells which sample to keep + def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep): + if n_eog == 0: + logits_adjust = logits + for jj in range(1,self.args.n_codebooks): + logits_adjust[:,jj,eog_inference] = -10000 + logits_adjust[:,jj,self.args.empty_token] = -10000 + ##################### silence repetition handling ##################### + for b in range(batch_size): + prev_token = prev_tokens[b] + consec_silence_count = consec_silence_counts[b] + if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition: + if logits_adjust[b, 0, prev_token] < 0: + logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] * (consec_silence_count - (stop_repetition-1)) + else: + logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] / (consec_silence_count - (stop_repetition-1)) + ##################### silence repetition handling ##################### + samples = topk_sampling( + logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature + ) # [B*K, 1] + samples = samples.reshape(batch_size, self.args.n_codebooks, 1) + assert samples.shape == torch.Size((batch_size, self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}" + for b in range(batch_size): + if cur_num_gen < self.args.n_codebooks-1: + for jj in range(1, self.args.n_codebooks - cur_num_gen): + samples[b, -jj, 0] = self.args.empty_token + + if ( + samples[b,0,0] == eog_inference or torch.argmax(logits[b,0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[b] * (self.args.encodec_sr//5) + ): # last one means y is already too long, shouldn't happen, but put it here + samples[b,0,0] = eog_inference + codebook_eog[0] = True + keep = b # NOTE keep is a very important variable, we only return this one, note that if eog shows up in two samples, keep will be overwritten by the later one (or the last one) + ##################### silence repetition handling ##################### + if samples[b,0,0] in silence_tokens and samples[b,0,0] == prev_tokens[b]: + consec_silence_counts[b] += 1 + else: + consec_silence_counts[b] = 0 + prev_tokens[b] = samples[b,0,0] + ##################### silence repetition handling ##################### + return samples, codebook_eog, prev_tokens, consec_silence_counts, keep + else: + assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}" + logits_adjust = logits + for jj in range(n_eog+1,self.args.n_codebooks): + logits_adjust[:,jj,eog_inference] = -10000 + logits_adjust[:,jj,self.args.empty_token] = -10000 + samples = topk_sampling( + logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature + ) # [B, K, 1] + samples = samples.reshape(batch_size, self.args.n_codebooks, 1) + for jj in range(n_eog): + samples[keep, jj, 0] = self.args.empty_token + samples[keep, n_eog, 0] = eog_inference + codebook_eog[n_eog] = True + return samples, codebook_eog, prev_tokens, consec_silence_counts, keep + while True: + # if cur_num_gen > 0, should have everything in kvcache, so only pass in the last token + # in the first generation step, we repeat each tensor to make their first dimension of length the batch size + if cur_num_gen == 0: + assert x_input.ndim == 3 and x_input.shape[0] == 1, x_input.shape + assert x_padding_mask.ndim == 2 and x_padding_mask.shape[0] == 1, x_padding_mask.shape + assert y_input.ndim == 3 and y_input.shape[0] == 1 and y_input.shape[1] == new_y_lens[0], y_input.shape + assert embedded_y.ndim == 3 and embedded_y.shape[0] == 1 and embedded_y.shape[1] == new_y_lens[0], embedded_y.shape + x_input = x_input.repeat(batch_size, 1, 1) + x_lens = x_lens.repeat(batch_size) + # x_attention_mask = x_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension + x_padding_mask = x_padding_mask.repeat(batch_size, 1) + y_input = y_input.repeat(batch_size, 1, 1) + new_y_lens = new_y_lens.repeat(batch_size) + # y_attention_mask = y_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension + y_padding_mask = y_padding_mask.repeat(batch_size, 1) + embedded_y = embedded_y.repeat(batch_size, 1, 1) # will be used to concat with newly generated token embedding + past = past.repeat(1, 1, batch_size) if past != None else None + else: + assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}" + y_out, present = self.dec_forward( + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + past=past + ) + if past != None: + past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype) + + # if no eog emerges, y_out should have batch size of batch_size + if sum(codebook_eog) == 0: + assert y_out.shape[0] == batch_size and y_out.ndim == 3, y_out.shape + y_out = y_out[:, -1:] # only take the last token + logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], S==1, so [B K 1 card] + logits = logits.squeeze(2) # [B K card] + assert logits.shape == torch.Size((batch_size, self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}" + + n_eog = sum(codebook_eog) + if self.args.eos > 0: + for jj in range(self.args.n_codebooks): + logits[:,jj,self.args.eog] = -10000. + samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep) + + cur_num_gen += 1 + if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples + assert keep == None + for b in range(batch_size): + cur_generated[b].append(samples[b].squeeze(-1)) + elif sum(codebook_eog) == 1: # the first eog just showed up in this step + assert keep != None + cur_generated = cur_generated[keep] + cur_generated.append(samples[keep].squeeze(-1)) + else: # we are generating the rest eogs for the 'keep' sample + cur_generated.append(samples[keep].squeeze(-1)) + + # samples.shape is [K,1] + # ge samples_emb + samples_emb = torch.stack([self.audio_embedding[k](samples[:, k]) for k in range(self.args.n_codebooks)], dim=1) # [B, K,1,D] + assert samples_emb.shape == torch.Size([batch_size, self.args.n_codebooks, 1, self.args.d_model]) + samples_emb = samples_emb.sum(dim=1,keepdim=False) # [B,1,D] + if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done + codebook_eog = [False] * self.args.n_codebooks + num_gen.append(cur_num_gen) + cur_num_gen = 0 + generated.append(cur_generated) + cur_generated = [[] for _ in range(batch_size)] + break + else: + assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" + + embedded_y = torch.cat([embedded_y, samples_emb], dim=1) + y_input = self.audio_positional_embedding(embedded_y) # [B T D] + # make attention mask and padding mask + y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) + new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size) + y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device) + + assert len(generated) == 1, f"len(generated): {len(generated)}" + + # revert the pattern + flatten_gen = [] + for l, orig_span in enumerate(generated): + span = torch.stack(orig_span, dim=0) # [T, K] + span = span.transpose(1,0) # [K, T] + assert span.shape[0] == self.args.n_codebooks, span.shape + unshifted_span = [] + for j, s in enumerate(span): + start_from = j + end_at = - (self.args.n_codebooks - start_from) + unshifted_span.append(s[start_from:end_at]) + unshifted_span = torch.stack(unshifted_span, dim=0) + + assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}" + + flatten_gen.append(unshifted_span) + assert len(flatten_gen) == 1, len(flatten_gen) + + # combine + res = [y[0], flatten_gen[0]] + res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T] + + expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen]) + assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}" + + if self.args.special_first: + res = res - int(self.args.n_special) + flatten_gen = flatten_gen - int(self.args.n_special) + + return res, flatten_gen[0].unsqueeze(0) \ No newline at end of file diff --git a/steps/__init__.py b/steps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/steps/optim.py b/steps/optim.py new file mode 100644 index 0000000..88bd02b --- /dev/null +++ b/steps/optim.py @@ -0,0 +1,1123 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names), f"len(param_group): {len(param_group)}, len(group_params_names): {len(group_params_names)}" + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [ + batches_names[batches_names_keys[idx]] for idx in sorted_idx + ] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [ + torch.zeros_like(p) if p.grad is None else p.grad + for p in batch + ] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + parameters_names=None, + show_dominant_parameters=True, + ): + + assert parameters_names is not None, ( + "Please prepare parameters_names," + "which is a List[List[str]]. Each List[str] is for a group" + "and each str is for a parameter" + ) + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + super(ScaledAdam, self).__init__(params, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + self.show_dominant_parameters = show_dominant_parameters + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip( + self.param_groups, self.parameters_names + ): + + with self.batched_params( + group["params"], group_params_names + ) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + numel = p.numel() + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = ( + (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += ( + grad ** 2 + ).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, + (clipping_update_period // 4) * n, + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter wihch dominanting tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad ** 2 + # Dummpy values used by following `zip` statement. + batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), + key=lambda item: item[1][0], + reverse=True, + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter Dominanting tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq = {(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p ** 2) + .mean(dim=list(range(1, p.ndim)), keepdim=True) + .sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2 ** size_update_period + + scale_exp_avg_sq = state[ + "scale_exp_avg_sq" + ] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads ** 2).mean( + dim=0 + ), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr ** size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr + * (bias_correction2 ** 0.5) + * scale_grads.sum(dim=0) + / denom + ) + + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + # when it gets too large, stop it from getting any larger. + scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - ( + state["zero_step"] if "zero_step" in state else 0 + ) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + def get_lr(self): + factor = ( + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + ) ** -0.25 * ( + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if not 0 <= weight_decay <= 0.1: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "AdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + # if random.random() < 0.0005: + # step = (exp_avg / denom) * step_size + # logging.info( + # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + # ) + + return loss + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans +def _test_scaled_adam(hidden_dim: int): + import timeit + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) + * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/steps/trainer.py b/steps/trainer.py new file mode 100644 index 0000000..0cce5de --- /dev/null +++ b/steps/trainer.py @@ -0,0 +1,467 @@ +import time +import os, random +import torch +import math, pickle +from tqdm import tqdm +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +import torch.nn as nn +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter +import numpy as np +from torch.utils.data.distributed import DistributedSampler +import logging +from data import gigaspeech +from models import voicecraft + +from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, AverageMeter, print_model_info +from .optim import ScaledAdam, Eden + + +class Trainer: + + def __init__(self, args, world_size, rank): + self.start_time = time.time() + self.args = args + self.world_size, self.rank = world_size, rank + self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") + if self.rank == 0: + self.writer = SummaryWriter(args.exp_dir) + self.seed_everything(seed=self.args.seed) + self.meters = self._setup_meters() + + self.progress, self.total_progress = self._setup_progress() + + self.model, self.trainables, self.optim_states, self.scheduler_states = self._setup_models() + + self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader() + if self.args.num_steps != None: + self.total_step = self.args.num_steps + self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None + else: + self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs + + self.optimizer, self.scheduler = self._setup_optimizer() + self.scaler = torch.cuda.amp.GradScaler() + self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank], find_unused_parameters=False) + + if self.rank == 0: + self.early_stop_accu_steps = 0 + if self.args.dynamic_batching: + logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in a inference batch: {self.args.val_max_num_tokens}") + else: + logging.info(f"batch size (summed over all GPUs): {self.args.batch_size}") + + def train(self): + flag = True + skip_flag = False + data_start_time = time.time() + while flag: + self.train_sampler.set_epoch(self.progress['epoch']) + for i, batch in enumerate(self.train_loader): + data_end_time = time.time() + self.model.train() + if self.progress['step'] > self.total_step: + flag = False + self.validate_and_save() + if self.rank == 0: + self.writer.close() + break + if isinstance(self.scheduler, Eden): + self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1) + if self.args.optimizer_name == "ScaledAdam": + cur_lr = self.scheduler.get_last_lr()[0] + else: + lrs = [param_group['lr'] for param_group in self.optimizer.param_groups] + assert lrs[0] == lrs[1] + cur_lr = lrs[0] + + if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0: + self.writer.add_scalar("train/lr", cur_lr, self.progress['step']) + self.wandb.log({"train/lr": cur_lr}, step=self.progress['step']) + + all_inds = list(range(len(batch['y']))) + sum_losses = 0 + sum_top10acc = 0 + sum_ntoken = 0 + sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)] + for j in range(self.args.gradient_accumulation_steps): + cur_ind = all_inds[j::self.args.gradient_accumulation_steps] + cur_batch = {key: batch[key][cur_ind] for key in batch} + with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32): + out = self.model(cur_batch) + + record_loss = out['loss'].detach().to(self.rank) + top10acc = out['top10acc'].to(self.rank) + effective_ntoken = out['effective_ntoken'].to(self.rank) + is_nan = torch.tensor(int(torch.isnan(record_loss).any()), dtype=torch.float32, device=self.rank) + + dist.all_reduce(record_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(top10acc, op=dist.ReduceOp.SUM) + dist.all_reduce(effective_ntoken, op=dist.ReduceOp.SUM) + dist.all_reduce(is_nan, op=dist.ReduceOp.SUM) + + # check if loss is nan + if is_nan.item() > 0: + logging.info(f"loss at step {self.progress['step']} is nan, therefore skip this batch") + skip_flag = True + continue + + sum_losses += record_loss.item() + sum_top10acc += top10acc.item() + sum_ntoken += effective_ntoken.item() + + if 'top10acc_by_codebook' in out: + for cb in range(self.args.n_codebooks): + top10acc_cbi = out['top10acc_by_codebook'][cb] + dist.all_reduce(top10acc_cbi, op=dist.ReduceOp.SUM) + sum_top10acc_cbi[cb] += top10acc_cbi.item() + + if self.rank == 0: + average_loss = sum_losses / sum_ntoken + average_top10acc = sum_top10acc / sum_ntoken + self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size) + self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size) + self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size) + average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)] + for cb in range(self.args.n_codebooks): + self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size) + + if self.progress['step'] % self.args.tb_write_every_n_steps == 0: + self.writer.add_scalar('train/loss', average_loss, self.progress['step']) + self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step']) + self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step']) + for cb in range(self.args.n_codebooks): + self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step']) + + if self.args.optimizer_name == "ScaledAdam": + self.scaler.scale(out['loss']).backward() + else: + self.scaler.scale(out['loss']/out['effective_ntoken']).backward() + + if skip_flag: + self.optimizer.zero_grad() + skip_flag = False + continue + + if self.args.optimizer_name != "ScaledAdam": + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val) + self.scaler.step(self.optimizer) + self.scaler.update() + + self.optimizer.zero_grad() + + if self.args.optimizer_name == "ScaledAdam": + self.scheduler.step_batch(self.progress['step']) + else: + self.scheduler.step() + + if self.rank == 0: + self.meters['data_time'].update(data_end_time - data_start_time) + self.meters['train_time'].update(time.time() - data_end_time) + if self.progress['step'] % self.args.tb_write_every_n_steps == 0: + self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step']) + self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step']) + + + # logging + if self.progress['step'] % self.args.print_every_n_steps == 0: + log_out = {} + log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}" + log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}" + log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}" + log_out['lr'] = f"{cur_lr:.7f}" + log_out['ntokens'] = f"{sum_ntoken}" + for key in self.meters: + if self.meters[key].val != 0 or self.meters[key].avg != 0: + log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}" + logging.info(log_out) + if np.isnan(self.meters['train_loss'].avg): + logging.warning("training diverged...") + raise RuntimeError("training diverged...") + + # validation and save models + if self.progress['step'] % self.args.val_every_n_steps == 0: + dist.barrier() + self.validate_and_save() + + self.progress['step'] += 1 + self.progress['cur_step'] += 1 + + data_start_time = time.time() + self.progress['epoch'] += 1 + self.progress['cur_step'] = 0 # reset cur_step to be 0 + dist.destroy_process_group() + + def validate_and_save(self): + self.model.eval() + + score = self.validate(self.valid_loader) + + if self.rank == 0: + if self.args.early_stop_threshold > 0: + if self.progress['best_score'] - score < self.args.early_stop_threshold: + self.early_stop_accu_steps += self.args.val_every_n_steps + if self.early_stop_accu_steps >= self.args.early_stop_step-1: + logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}") + logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}") + dist.destroy_process_group() + raise RuntimeError("early stop") + else: + self.early_stop_accu_steps = 0 + + if (score < self.progress['best_score']): + self.progress['best_step'] = self.progress['step'] + self.progress['best_score'] = score + save_path = os.path.join(self.args.exp_dir,"best_bundle.pth") + torch.save( + { + "model": self.model.module.state_dict(), + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler.state_dict(), + "config": self.args, + "phn2num": self.train_loader.dataset.phn2num + },save_path + ) + logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}") + self._save_progress() + save_path = os.path.join(self.args.exp_dir,"bundle.pth") + torch.save( + { + "model": self.model.module.state_dict(), + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler.state_dict(), + "config": self.args, + "phn2num": self.train_loader.dataset.phn2num + },save_path + ) + logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}") + + dist.barrier() + + def validate(self, valid_loader=None, hide_progress=True): + if valid_loader == None: + valid_loader = self.valid_loader + self.model.eval() + + start_val_time = time.time() + sum_losses = 0 + sum_top10acc = 0 + sum_ntoken = 0 + sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)] + + with torch.no_grad(): + for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)): + out = self.model(batch) + sum_losses += out['loss'] + sum_top10acc += out['top10acc'] + sum_ntoken += out['effective_ntoken'] + if 'top10acc_by_codebook' in out: + for cb in range(self.args.n_codebooks): + sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb] + + dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM) + dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM) + dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM) + + if 'top10acc_by_codebook' in out: + for cb in range(self.args.n_codebooks): + dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM) + + if self.rank == 0: + val_loss = sum_losses / sum_ntoken + val_top10acc = sum_top10acc / sum_ntoken + # logging + self.meters['val_loss'].update(val_loss) + logging.info(f"val loss: {val_loss:.5f}") + self.writer.add_scalar("val/loss", val_loss, self.progress['step']) + + self.meters['val_top10acc'].update(val_top10acc) + logging.info(f"val top10acc: {val_top10acc:.5f}") + self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step']) + for cb in range(self.args.n_codebooks): + average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks + self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi) + self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step']) + + logging.info(f"validation takes: {time.time() - start_val_time:.2f}s") + logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}") + return val_loss.item() + else: + return None + + def _setup_meters(self): + meters = {} + meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time'] + meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc'] + meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)] + meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)] + for name in meter_names: + meters[name] = AverageMeter() + return meters + def _setup_progress(self): + progress = {} + progress['best_step'] = 1 + progress['best_score'] = np.inf # this records loss value + progress['step'] = 1 + progress['epoch'] = 1 + progress['cur_step'] = 0 # step in the current epoch, for resuming the sampler + total_progress = [] + # if self.args.resume or self.args.validate: + if self.args.resume: + progress_pkl = "%s/progress.pkl" % self.args.exp_dir + with open(progress_pkl, "rb") as f: + total_progress = pickle.load(f) + progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1] + if self.rank == 0: + logging.info("\nResume training from:") + logging.info(" epoch = %s" % progress['epoch']) + logging.info(" cur_step = %s" % progress['cur_step']) + logging.info(" step = %s" % progress['step']) + logging.info(" best_step = %s" % progress['best_step']) + logging.info(" best_score = %s" % progress['best_score']) + return progress, total_progress + + def _save_progress(self): + self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time]) + with open("%s/progress.pkl" % self.args.exp_dir, "wb") as f: + pickle.dump(self.total_progress, f) + + def _setup_dataloader(self): + assert self.args.dataset == 'gigaspeech', "only gigaspeech is supported for now" + train_dataset, val_dataset = gigaspeech.dataset(self.args, 'train'), gigaspeech.dataset(self.args, 'validation') + + if self.args.dynamic_batching: + train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0) + valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0) + else: + train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True) + valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False) + + if self.progress['step'] > 1: + train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step']) + + if self.args.dynamic_batching: + train_loader = torch.utils.data.DataLoader(train_dataset, + batch_sampler=train_sampler, + num_workers=self.args.num_workers//self.world_size, + collate_fn=train_dataset.collate, persistent_workers=True + ) + valid_loader = torch.utils.data.DataLoader(val_dataset, + batch_sampler=valid_sampler, + num_workers=self.args.num_workers//self.world_size, + collate_fn=val_dataset.collate, persistent_workers=True + ) + else: + train_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=self.args.batch_size//self.world_size, sampler=train_sampler, num_workers=self.args.num_workers//self.world_size, + collate_fn=train_dataset.collate, persistent_workers=True + ) + valid_loader = torch.utils.data.DataLoader(val_dataset, + batch_size=self.args.batch_size//self.world_size, sampler=valid_sampler, + num_workers=self.args.num_workers//self.world_size, + collate_fn=val_dataset.collate, persistent_workers=True + ) + return len(train_dataset), train_sampler, train_loader, valid_loader + + + + def _setup_models(self): + model = voicecraft.VoiceCraft(self.args) + + if self.rank == 0: + logging.info(model) + logging.info("model parameters") + print_model_info(model) + + if self.progress['step'] > 1: + bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu") + model.load_state_dict(bundle['model']) + optim_states = bundle['optimizer'] + scheduler_states = bundle['scheduler'] + if self.rank == 0: + logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step'])) + del bundle['model'] + else: + optim_states = None + scheduler_states = None + + if self.args.load_model_from != None and self.progress['step'] <= 1: + sd = torch.load(self.args.load_model_from, map_location="cpu")['model'] + model.load_state_dict(sd) + del sd + + if self.args.optimizer_name == "ScaledAdam": + trainables = [p for p in model.parameters() if p.requires_grad] + else: + no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], + "weight_decay": 0.0, + }, + ] + if len(optimizer_grouped_parameters[1]['params']) == 0: + logging.info("there is no embedding weights, bias, and layernorm parameters in the model, which should be True, check model parameter names") + trainables = optimizer_grouped_parameters[0] + else: + trainables = optimizer_grouped_parameters + model.to(self.device) + + return model, trainables, optim_states, scheduler_states + + + def _setup_optimizer(self): + if self.args.optimizer_name == "ScaledAdam": + parameters_names = [] + parameters_names.append([n for n,p in self.model.named_parameters() if p.requires_grad]) + optimizer = ScaledAdam( + self.trainables, + lr=self.args.lr, + betas=(0.9, 0.95), + clipping_scale=2.0, + parameters_names=parameters_names, + show_dominant_parameters=False, + clipping_update_period=self.args.clipping_update_period, + ) + scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction) + + else: + optimizer = AdamW(self.trainables, lr=self.args.lr) + warmup_steps = self.total_step * self.args.warmup_fraction + def lr_lambda(current_step: int): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps)) + ) + + scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) + + # if resume + if self.progress['step'] > 1: + optimizer.load_state_dict(self.optim_states) + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda() + del self.optim_states + + scheduler.load_state_dict(self.scheduler_states) + + optimizer.zero_grad() + return optimizer, scheduler + + def seed_everything(self, seed=1): + os.environ['PYTHONHASHSEED'] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True \ No newline at end of file diff --git a/steps/trainer_utils.py b/steps/trainer_utils.py new file mode 100644 index 0000000..65e2d14 --- /dev/null +++ b/steps/trainer_utils.py @@ -0,0 +1,628 @@ + +import torch +import math +import torch.distributed as dist +from torch.utils.data.sampler import Sampler +import copy +import numpy as np +from typing import List +from scipy.stats import lognorm +import logging + +class StatefulDistributedSampler(Sampler[int]): + def __init__(self, dataset, batch_size, num_replicas = None, rank = None, shuffle = True, seed = 0, drop_last = False): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + self.dataset = dataset + self.batch_size = batch_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.cur_epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + self.continue_flag = False + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + self.indices = indices + + if self.continue_flag: + self.indices = self.indices[int(self.cur_step*self.batch_size):] + self.num_samples = len(self.indices) + self.continue_flag = False + + def __iter__(self): + for idx in self.indices: + yield idx + + def set_epoch_resume(self, epoch, cur_step): + self.epoch = epoch + self.cur_step = cur_step + self.continue_flag = True + + +class StatefulSampler(Sampler): + def __init__(self, data_source_length, batch_size, use_random=True, seed=1, epoch=0): + self.use_random = use_random + self.data_source_length = data_source_length + self.num_samples = self.data_source_length + self.batch_size = batch_size + self.continue_flag = False + self.seed = seed + self.epoch = epoch + self.cur_step = 0 + + def __len__(self): + return self.num_samples + + def __iter__(self): + + for idx in self.indices: + yield idx + + def set_epoch(self, epoch): + self.epoch = epoch + if self.use_random: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + self.indices = torch.randperm(self.data_source_length, generator=g).tolist() # type: ignore[arg-type] + else: + self.indices = list(range(self.data_source_length)) # type: ignore[arg-type] + if self.continue_flag == True: + self.continue_flag = False + self.indices = self.indices[int(self.cur_step*self.batch_size):] + + self.num_samples = len(self.indices) + + def set_epoch_resume(self, epoch, cur_step): + self.epoch = epoch + self.cur_step = cur_step + self.continue_flag = True + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def print_model_info(model, print_model = False, print_params = True): + if print_model: + logging.info(model) + if print_params: + all_params = {} + for name, p in model.named_parameters(): + name = name.split(".")[0] + if name in all_params: + all_params[name] += p.numel() + else: + all_params[name] = p.numel() + logging.info("num of parameters of each components:") + for name in all_params: + logging.info(f"{name}: {all_params[name]/1000000.:.2f}m") + + +class DistributedDynamicBatchSampler(Sampler): + """ + modified from SpeechBrian, https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/sampler.py#L307 + This BatchSampler batches examples together by grouping them by their length. + + Every example in the batch have approximately the same length and + thus padding is minimized. + This enables faster training on datasets + where length of examples can vary significantly (e.g Librispeech). + Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length + + Dynamic batching is performed by specifying a max_batch_length which is the + upper limit for the sum of the length of examples in a batch: + e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6 + ex1 and ex2 will be placed, alone, in two distinct batches. + + Length for each example can be obtained in two manners. + If the input dataset is a DynamicItemDataset it can be obtained by specifying a + length_func. Default assumes a "duration" entry is in the annotation. + Length for each example can also be passed to this class upon instantiation + by specifying a list containing the length for each example and passing it to + lengths_list. + + Examples are grouped together by defining a set of possible discrete intervals + (buckets). Examples whose length fall into these intervals can be batched together. + + The number of buckets can be specified by using the arg num_buckets. + There is usually an optimal range for the value of this argument. + + If num_buckets == 1, all examples can be batched together. You have maximum randomization + but your training speed will be slower due to the fact that a large amount of the values will be padding + as long and short examples can be batched together. + As the number of buckets grows only examples with similar + length can be grouped together. + This trades-off speed with randomization. + TLDR: Low number -> better randomization, High number -> faster training. + NOTE THAT: if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size + will be small impacting training speed and possibly performance. + + The buckets can also be specified by passing a list to the bucket_boundaries + argument instead of specifying a left_bucket_length and a bucket_length_multiplier. + + Example + ------- + >>> import torch + >>> import speechbrain as sb + >>> from speechbrain.dataio.sampler import DynamicBatchSampler + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> from speechbrain.dataio.batch import PaddedBatch + >>> import numpy as np + >>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)]) + >>> dataset = {"ex_{}".format(x) : {"wav" :torch.randn(x)} for x in item_lengths} + >>> dataset = DynamicItemDataset(dataset) + >>> dataset.set_output_keys(["wav"]) + >>> length_func = lambda x : len(x) # trivial in this example + >>> bsampler = DynamicBatchSampler(dataset, 20, 4, length_func, shuffle=False, batch_ordering='descending') + >>> dataloader = SaveableDataLoader(dataset, batch_sampler=bsampler, collate_fn=PaddedBatch) + >>> for i, b in enumerate(dataloader): + ... data, length = b["wav"] + >>> assert data.shape[-1] == max(item_lengths) + + Arguments + --------- + dataset : torch.utils.data.Dataset + Pytorch Dataset from which elements will be sampled. + max_batch_length : int + Upper limit for the sum of the length of examples in a batch. + Should be chosen based on your GPU memory. + num_buckets : int + Number of discrete buckets used to group examples together. + If num_buckets == 1, all examples can be batched together. As the number of buckets grows only examples with similar + length can be grouped together. This trades-off speed with randomization. + Low number -> better randomization, High number -> faster training. + However if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size + will be small impacting training speed and possibly performance. + NOTE: you have either to specify manually the bucket_boundaries or the number of buckets. + length_func : callable + Function used to get length of each example from the dataset. + This argument can be used only when the dataset is a Speechbrain DynamicItemDataset object. + Can be anything: e.g. lambda x: x["duration"]*16000 returns number of samples + if duration key in the annotation is in seconds and the file has 16kHz sampling freq. + shuffle : bool + Whether or not shuffle examples between each epoch. + batch_ordering : string + If ``random``, batches are randomly permuted; otherwise ``ascending`` or ``descending`` sorted by length. + max_batch_ex: int + If set, it limits the maximum number of examples that can be in a batch superseeding max_batch_length + in instances where the amount of examples will exceeed the value specified here. + E.g. you have a lot of short examples and the batch size for those will be too high, you can use this argument + to limit the batch size for these short examples. + bucket_boundaries : list + Overrides bucket_length_multiplier and left_bucket_length by specifying manually + the buckets right boundaries. + lengths_list: list + Overrides length_func by passing a list containing the length of each example + in the dataset. This argument must be set when the dataset is a plain + Pytorch Dataset object and not a DynamicItemDataset object as length_func + cannot be used on Pytorch Datasets. + epoch : int + The epoch to start at. + drop_last : bool + If ``True``, the sampler will drop the last examples which + have not been grouped. + verbose: bool + If ``True``, log also the stats for each batch at the first epoch. + """ + + def __init__( + self, + dataset, + args, + num_replicas = None, + rank = None, + shuffle = True, + seed = 0, + drop_last = False, + length_func=lambda x: x["duration"], + batch_ordering: str = "random", + max_batch_ex: int = None, + bucket_boundaries: List[int] = [], + lengths_list: List[int] = None, + epoch: int = 0, + verbose: bool = False, + ): + self.args = args + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + self.num_replicas = num_replicas + self.rank = rank + max_batch_length = self.args.max_num_tokens if dataset.split == "train" else self.args.val_max_num_tokens + logging.info(f"max_num_tokens per GPU for {dataset.split} split: {max_batch_length}") + num_buckets = self.args.num_buckets + ############# + + + + + self._dataset = dataset + self._ex_lengths = {} + # ex_ids = self._dataset.data_ids + self.verbose = verbose + + # We do not put a default on num_buckets to encourage users to play with this parameter + if num_buckets is None and len(bucket_boundaries) == 0: + raise RuntimeError( + "Please specify either num_buckets or bucket boundaries." + "Check the docs, and/or the tutorial !" + ) + assert lengths_list != None + max_len = int(self.args.audio_max_length * self.args.encodec_sr) + lengths_list = [min(l, max_len) for l in lengths_list] # replace all utt whose length is longer than max_len to max_len, will also do this in __getitem__ in dataset + for indx in range(len(lengths_list)): + self._ex_lengths[str(indx)] = lengths_list[indx] + # if lengths_list is not None: + # # take length of examples from this argument and bypass length_key + # for indx in range(len(lengths_list)): + # self._ex_lengths[str(indx)] = lengths_list[indx] + # else: + # # use length func + # if not isinstance(dataset, DynamicItemDataset): + # raise NotImplementedError( + # "Dataset should be a Speechbrain DynamicItemDataset when using length function" + # ) + # for indx in range(len(self._dataset)): + # self._ex_lengths[str(indx)] = length_func( + # self._dataset.data[ex_ids[indx]] + # ) + + if len(bucket_boundaries) > 0: + if not all([x >= 0 for x in bucket_boundaries]): + raise ValueError( + "All elements in bucket boundaries should be non-negative (>= 0)." + ) + if not len(set(bucket_boundaries)) == len(bucket_boundaries): + raise ValueError( + "Bucket_boundaries should not contain duplicates." + ) + np.testing.assert_array_equal( + np.array(bucket_boundaries), + np.array(sorted(bucket_boundaries)), + err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!", + ) + self._bucket_boundaries = np.array(sorted(bucket_boundaries)) + else: + # use num_buckets + self._bucket_boundaries = np.array( + self._get_boundaries_through_warping( + # max_batch_length=max_batch_length, + max_batch_length=max(lengths_list), + num_quantiles=num_buckets, + ) + ) + + self._max_batch_length = max_batch_length + self._shuffle_ex = shuffle + self._batch_ordering = batch_ordering + self._seed = seed + self._drop_last = drop_last + if max_batch_ex is None: + max_batch_ex = np.inf + self._max_batch_ex = max_batch_ex + # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length? + self._bucket_lens = [ + max(1, int(max_batch_length / self._bucket_boundaries[i])) + for i in range(len(self._bucket_boundaries)) + ] + [1] + self._epoch = epoch + self._cur_step = 0 + self.continue_flag = False + self._generate_batches() + self.num_samples = int(math.floor(len(self._batches) / self.num_replicas)) + self.total_size = int(self.num_samples * self.num_replicas) + self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas] + assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}" + logging.info(f"len(self._batches): {len(self._batches)}") + logging.info(f"self.num_replicas: {self.num_replicas}") + logging.info(f"num of batches on each replica: {self.num_samples}") + + def get_durations(self, batch): + """Gets durations of the elements in the batch.""" + return [self._ex_lengths[str(idx)] for idx in batch] + + def _get_boundaries_through_warping( + self, max_batch_length: int, num_quantiles: int, + ) -> List[int]: + + # NOTE: the following lines do not cover that there is only one example in the dataset + # warp frames (duration) distribution of train data + logging.info("Batch quantisation in latent space") + # linspace set-up + num_boundaries = num_quantiles + 1 + # create latent linearly equal spaced buckets + latent_boundaries = np.linspace( + 1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles, + ) + # get quantiles using lognormal distribution + quantiles = lognorm.ppf(latent_boundaries, 1) + # scale up to to max_batch_length + bucket_boundaries = quantiles * max_batch_length / quantiles[-1] + # compute resulting bucket length multipliers + length_multipliers = [ + bucket_boundaries[x + 1] / bucket_boundaries[x] + for x in range(num_quantiles - 1) + ] + # logging + logging.debug( + "Latent bucket boundary - buckets: {} - length multipliers: {}".format( + list(map("{:.2f}".format, bucket_boundaries)), + list(map("{:.2f}".format, length_multipliers)), + ) + ) + return list(sorted(bucket_boundaries)) + + def _permute_batches(self): + + if self._batch_ordering == "random": + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self._seed + self._epoch) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset + sampler = torch.randperm( + len(self._batches), generator=g + ).tolist() # type: ignore + tmp = [] + for idx in sampler: + tmp.append(self._batches[idx]) + self._batches = tmp + + elif self._batch_ordering == "ascending": + self._batches = sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + ) + elif self._batch_ordering == "descending": + self._batches = sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + reverse=True, + ) + else: + raise NotImplementedError + + def _generate_batches(self): + logging.info("DynamicBatchSampler: Generating dynamic batches") + if self._shuffle_ex: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self._seed + self._epoch) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset + sampler = torch.randperm(len(self._dataset), generator=g).tolist() # type: ignore + # pyp note: this is actually randomly permoted indices + else: + # take examples as they are: e.g. they have been sorted + sampler = range(len(self._dataset)) # type: ignore + + self._batches = [] + bucket_batches = [[] for i in self._bucket_lens] + + stats_tracker = [ + {"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} + for i in self._bucket_lens + ] + + for idx in sampler: + # length of pre-sampled audio + item_len = self._ex_lengths[str(idx)] + # bucket to fill up most padding + bucket_id = np.searchsorted(self._bucket_boundaries, item_len) + # fill audio's duration into that bucket + bucket_batches[bucket_id].append(idx) + + stats_tracker[bucket_id]["min"] = min( + stats_tracker[bucket_id]["min"], item_len + ) + stats_tracker[bucket_id]["max"] = max( + stats_tracker[bucket_id]["max"], item_len + ) + stats_tracker[bucket_id]["tot"] += item_len + stats_tracker[bucket_id]["n_ex"] += 1 + # track #samples - why not duration/#frames; rounded up? + # keep track of durations, if necessary + + if ( + len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id] + or len(bucket_batches[bucket_id]) >= self._max_batch_ex + ): + self._batches.append(bucket_batches[bucket_id]) + bucket_batches[bucket_id] = [] + # keep track of durations + + # Dump remaining batches + if not self._drop_last: + for batch in bucket_batches: + if batch: + self._batches.append(batch) + + self._permute_batches() # possibly reorder batches + + if self._epoch == 0: # only log at first epoch + # frames per batch & their padding remaining + boundaries = [0] + self._bucket_boundaries.tolist() + + for bucket_indx in range(len(self._bucket_boundaries)): + try: + num_batches = stats_tracker[bucket_indx]["tot"] // ( + self._max_batch_length + ) + pad_factor = ( + stats_tracker[bucket_indx]["max"] + - stats_tracker[bucket_indx]["min"] + ) / ( + stats_tracker[bucket_indx]["tot"] + / stats_tracker[bucket_indx]["n_ex"] + ) + except ZeroDivisionError: + num_batches = 0 + pad_factor = 0 + + logging.debug( + ( + "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and " + + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}." + ).format( + bucket_indx, + boundaries[bucket_indx], + boundaries[bucket_indx + 1], + self._bucket_lens[bucket_indx], + stats_tracker[bucket_indx]["n_ex"], + num_batches, + pad_factor * 100, + ) + ) + + if self.verbose: + batch_stats = { + "tot_frames": [], + "tot_pad_frames": [], + "pad_%": [], + } + for batch in self._batches: + tot_frames = sum( + [self._ex_lengths[str(idx)] for idx in batch] + ) + batch_stats["tot_frames"].append(tot_frames) + max_frames = max( + [self._ex_lengths[str(idx)] for idx in batch] + ) + tot_pad = sum( + [ + max_frames - self._ex_lengths[str(idx)] + for idx in batch + ] + ) + batch_stats["tot_pad_frames"].append(tot_pad) + batch_stats["pad_%"].append(tot_pad / tot_frames * 100) + + padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total." + padding_details = "DynamicBatchSampler: " + padding_details + for i in range(len(self._batches)): + logging.debug( + padding_details.format( + i, + batch_stats["tot_frames"][i], + len(self._batches[i]), + batch_stats["tot_pad_frames"][i], + batch_stats["pad_%"][i], + ) + ) + + def __iter__(self): + + for batch in self._replica_batches: + yield batch + + + # if self._shuffle_ex: # re-generate examples if ex_ordering == "random" + # self._generate_batches() + # if self._batch_ordering == "random": + # # we randomly permute the batches only --> faster + # self._permute_batches() + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self._epoch = epoch + self._generate_batches() + self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas] + self.num_samples = int(math.floor(len(self._batches) / self.num_replicas)) + assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}" + + if self.continue_flag: + self.continue_flag = False + self._replica_batches = self._replica_batches[self._cur_step:] + self.num_samples = len(self._replica_batches) + + + def __len__(self): + return self.num_samples + + def set_epoch_resume(self, epoch, cur_step): + self.continue_flag = True + self._epoch = epoch + self._cur_step = cur_step diff --git a/z_scripts/e830M.sh b/z_scripts/e830M.sh new file mode 100644 index 0000000..5394e83 --- /dev/null +++ b/z_scripts/e830M.sh @@ -0,0 +1,71 @@ +#!/bin/bash +source ~/miniconda3/etc/profile.d/conda.sh +conda activate voicecraft +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export WORLD_SIZE=4 + +dataset=gigaspeech +mkdir -p ./logs/${dataset} + +exp_root="/data/scratch/pyp/exp_pyp/VoiceCraft" +exp_name=e830M +dataset_dir="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl" +encodec_codes_folder_name="encodec_16khz_4codebooks" + +# export CUDA_LAUNCH_BLOCKING=1 # for debugging + +torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_per_node=${WORLD_SIZE} \ +../main.py \ +--reduced_eog 1 \ +--drop_long 1 \ +--eos 2051 \ +--n_special 4 \ +--pad_x 0 \ +--codebook_weight "[5,1,0.5,0.1]" \ +--encodec_sr 50 \ +--num_steps 50000 \ +--lr 0.05 \ +--warmup_fraction 0.01 \ +--optimizer_name "ScaledAdam" \ +--pseudo_epoch_size 3000 \ +--reduce_lr_start_step 3000 \ +--reduce_lr_start_epoch 4 \ +--clipping_update_period 1000 \ +--d_model 2048 \ +--audio_embedding_dim 2048 \ +--nhead 16 \ +--num_decoder_layers 16 \ +--max_num_tokens 100000 \ +--gradient_accumulation_steps 26 \ +--val_max_num_tokens 6000 \ +--num_buckets 6 \ +--audio_max_length 20 \ +--audio_min_length 2 \ +--text_max_length 400 \ +--text_min_length 10 \ +--mask_len_min 1 \ +--mask_len_max 600 \ +--tb_write_every_n_steps 10 \ +--print_every_n_steps 400 \ +--val_every_n_steps 1600 \ +--text_vocab_size 100 \ +--text_pad_token 100 \ +--phn_folder_name "phonemes" \ +--manifest_name "manifest_large16khz_lessambi" \ +--encodec_folder_name ${encodec_codes_folder_name} \ +--audio_vocab_size 2048 \ +--empty_token 2048 \ +--eog 2049 \ +--audio_pad_token 2050 \ +--n_codebooks 4 \ +--max_n_spans 3 \ +--shuffle_mask_embedding 0 \ +--mask_sample_dist poisson1 \ +--max_mask_portion 0.9 \ +--min_gap 5 \ +--num_workers 8 \ +--dynamic_batching 1 \ +--dataset $dataset \ +--exp_dir "${exp_root}/${dataset}/${exp_name}" \ +--dataset_dir ${dataset_dir} +# >> ./logs/${dataset}/${exp_name}.log 2>&1 \ No newline at end of file