Commit 6760f29bd0 by jason-on-salt-a40, 2024-03-21 11:02:20 -07:00
32 changed files with 9321 additions and 0 deletions

.gitignore (vendored, new file, 25 lines)
@@ -0,0 +1,25 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info
.pytest_cache
.ipynb_checkpoints
thumbs.db
.DS_Store
.idea
*.log
*.pdf
*.mkv
*.mp4
*.png
*.wav
*.mp3
*durip*
*rtx*
*l40*
*a40*
!/demo/
!/demo/*
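The trailing `!/demo/` and `!/demo/*` lines re-include the demo directory that broad patterns such as `*.png` and `*.wav` would otherwise ignore (in gitignore, the last matching pattern wins). A quick way to check this behavior in a throwaway repository (requires git; paths here are hypothetical):

```shell
# create a scratch repo with the relevant subset of the ignore rules
d="$(mktemp -d)"
git -C "$d" init -q
printf '*.png\n!/demo/\n!/demo/*\n' > "$d/.gitignore"
mkdir "$d/demo"
touch "$d/top.png" "$d/demo/sample.png"

# top.png matches *.png, so it is ignored (check-ignore exits 0 and prints the path)
git -C "$d" check-ignore top.png
# demo/sample.png is re-included by the later !/demo/* pattern, so it is NOT ignored
git -C "$d" check-ignore demo/sample.png || echo "demo/sample.png is not ignored"
```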

LICENSE-CODE (new file, 437 lines)
@@ -0,0 +1,437 @@
Attribution-NonCommercial-ShareAlike 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at
creativecommons.org/compatiblelicenses, approved by Creative
Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
e. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name
of a Creative Commons Public License. The License Elements of this
Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
i. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
k. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
l. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. Additional offer from the Licensor -- Adapted Material.
Every recipient of Adapted Material from You
automatically receives an offer from the Licensor to
exercise the Licensed Rights in the Adapted Material
under the conditions of the Adapter's License You apply.
c. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
b. ShareAlike.
In addition to the conditions in Section 3(a), if You Share
Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons
license with the same License Elements, this version or
later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the
Adapter's License You apply. You may satisfy this condition
in any reasonable manner based on the medium, means, and
context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms
or conditions on, or apply any Effective Technological
Measures to, Adapted Material that restrict exercise of the
rights granted under the Adapter's License You apply.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material,
including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

LICENSE-MODEL (new file, 42 lines)
@@ -0,0 +1,42 @@
Coqui Public Model License 1.0.0
https://coqui.ai/cpml.txt
This license allows only non-commercial use of a machine learning model and its outputs.
Acceptance
In order to get any license under these terms, you must agree to them as both strict obligations and conditions to all your licenses.
Licenses
The licensor grants you a copyright license to do everything you might do with the model that would otherwise infringe the licensor's copyright in it, for any non-commercial purpose. The licensor grants you a patent license that covers patent claims the licensor can license, or becomes able to license, that you would infringe by using the model in the form provided by the licensor, for any non-commercial purpose.
Non-commercial Purpose
Non-commercial purposes include any of the following uses of the model or its output, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output.
Personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, amateur pursuits, or religious observance.
Use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development. Use of the model to train other models for commercial use is not a non-commercial purpose.
Use by any charitable organization for charitable purposes, or for testing or evaluation. Use for revenue-generating activity, including projects directly funded by government grants, is not a non-commercial purpose.
Notices
You must ensure that anyone who gets a copy of any part of the model, or any modification of the model, or their output, from you also gets a copy of these terms or the URL for them above.
No Other Rights
These terms do not allow you to sublicense or transfer any of your licenses to anyone else, or prevent the licensor from granting licenses to anyone else. These terms do not imply any other licenses.
Patent Defense
If you make any written claim that the model infringes or contributes to infringement of any patent, your licenses for the model granted under these terms ends immediately. If your company makes such a claim, your patent license ends immediately for work on behalf of your company.
Violations
The first time you are notified in writing that you have violated any of these terms, or done anything with the model or its output that is not covered by your licenses, your licenses can nonetheless continue if you come into full compliance with these terms, and take practical steps to correct past violations, within 30 days of receiving notice. Otherwise, all your licenses end immediately.
No Liability
AS FAR AS THE LAW ALLOWS, THE MODEL AND ITS OUTPUT COME AS IS, WITHOUT ANY WARRANTY OR CONDITION, AND THE LICENSOR WILL NOT BE LIABLE TO YOU FOR ANY DAMAGES ARISING OUT OF THESE TERMS OR THE USE OR NATURE OF THE MODEL OR ITS OUTPUT, UNDER ANY KIND OF LEGAL CLAIM. IF THIS PROVISION IS NOT ENFORCEABLE IN YOUR JURISDICTION, YOUR LICENSES ARE VOID.
Definitions
The licensor is the individual or entity offering these terms, and the model is the model the licensor makes available under these terms, including any documentation or similar information about the model.
You refers to the individual or entity agreeing to these terms.
Your company is any legal entity, sole proprietorship, or other kind of organization that you work for, plus all organizations that have control over, are under the control of, or are under common control with that organization. Control means ownership of substantially all the assets of an entity, or the power to direct its management and policies by vote, contract, or otherwise. Control can be direct or indirect.
Your licenses are all the licenses granted to you under these terms.
Use means anything you do with the model or its output requiring one of your licenses.

README.md (new file, 69 lines)
@@ -0,0 +1,69 @@
# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
[Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf)
TL;DR:
VoiceCraft is a token-infilling neural codec language model that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data, including audiobooks, internet videos, and podcasts.
To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference audio.
## TODO
The remaining TODOs will be completed by the end of March 2024.
- [x] Codebase upload
- [x] Environment setup
- [x] Inference demo for speech editing and TTS
- [ ] Upload model weights
- [ ] Training guidance
- [ ] Upload the RealEdit dataset
## Environment setup
```bash
conda create -n voicecraft python=3.9.16
conda activate voicecraft
pip install torch==2.0.1 torchaudio==2.0.2 # this assumes your system is compatible with CUDA 11.7; otherwise check out https://pytorch.org/get-started/previous-versions/#v201
apt-get install ffmpeg # if you don't already have ffmpeg installed
pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
apt-get install espeak-ng # backend for the phonemizer installed below
pip install phonemizer==3.2.1
pip install tensorboard
pip install datasets==2.12.0
# install MFA for forced alignment; this could take a few minutes
conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068
# conda install pocl # the above gives a warning about pocl; unclear whether it's really needed
# to run ipynb
conda install -n voicecraft ipykernel --update-deps --force-reinstall
```
## Inference Examples
Check out [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb)
## License
The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under the Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some code from other repositories under different licenses: `./models/codebooks_patterns.py` is under the MIT license; `./models/modules`, `./steps/optim.py`, and `data/tokenizer.py` are under the Apache License, Version 2.0; the phonemizer we use is under the GNU GPL v3.0 license. For a drop-in replacement of the phonemizer (i.e. text-to-IPA phoneme mapping), try [g2p](https://github.com/roedoejet/g2p) (MIT License) or [OpenPhonemizer](https://github.com/NeuralVox/OpenPhonemizer) (BSD-3-Clause Clear), though we have not tested them.
<!-- How to use g2p to convert english text into IPA phoneme sequence
first install it with `pip install g2p`
```python
from g2p import make_g2p
transducer = make_g2p('eng', 'eng-ipa')
transducer("hello").output_string
# it will output: 'hʌloʊ'
``` -->
## Acknowledgement
We thank Feiteng for his [VALL-E reproduction](https://github.com/lifeiteng/vall-e), and we thank the audiocraft team for open-sourcing [encodec](https://github.com/facebookresearch/audiocraft).
## Citation
```bibtex
@article{peng2024voicecraft,
author = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David},
title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
journal = {arXiv},
year = {2024},
}
```
## Disclaimer
Any organization or individual is prohibited from using any technology mentioned in this paper to generate or edit someone's speech without their consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.

config.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import argparse
def MyParser():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# general training
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--precision", type=str, default="float16")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--resume", action="store_true", default=False)
parser.add_argument("--tb_write_every_n_steps", type=int, default=100)
parser.add_argument("--print_every_n_steps", type=int, default=400)
parser.add_argument("--val_every_n_steps", type=int, default=800)
parser.add_argument("--lr", type=float, default=0.05)
parser.add_argument("--batch_size", type=int, default=100, help="the effective batch size regardless of gradient_accumulation_steps; not used if max_num_tokens is specified")
parser.add_argument("--max_num_tokens", type=int, default=100000, help="max number of encodec tokens per gpu; only used with dynamic batching, and overrides batch_size. Note this is the final effective batch size per GPU, i.e. the gradient-accumulated batch size per gpu")
parser.add_argument("--val_max_num_tokens", type=int, default=None, help="for validation")
parser.add_argument("--num_buckets", type=int, default=6, help='used for dynamic batching, bucketing the samples based on the number of tokens')
parser.add_argument("--dynamic_batching", type=int, default=0)
parser.add_argument("--weight_decay", type=float, default=1e-2)
parser.add_argument("--warmup_fraction", type=float, default=0.01, help="use linear warmup, the proportion of the training steps that are used for warming up")
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore num_epochs and use num_steps as the total number of training steps, e.g. 400000 i.e. 400k steps")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_(), not used if we use ScaledAdam optimizer")
parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement")
parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps")
# optimizer focused
parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler")
parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='step after which the lr is significantly reduced; a param for the Eden scheduler')
parser.add_argument("--pseudo_epoch_size", type=int, default=3000, help="only use for Eden scheduler.")
parser.add_argument("--reduce_lr_start_epoch", type=int, default=4)
parser.add_argument("--clipping_update_period", type=int, default=600)
# path
parser.add_argument("--exp_dir", type=str, default=None, help="will be combined with dataset name")
parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'gigaspeech', they are folder name in the data dir also")
parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file")
parser.add_argument("--phn_folder_name", type=str, default="phonemes", help="for libritts I also have arpa phns, in which case should be phonemes_arpa")
parser.add_argument("--encodec_folder_name", type=str, default="encodec_16khz_4codebooks", help="folder where encodec codes are stored")
parser.add_argument("--manifest_name", type=str, default="manifest", help="metadata filename")
# data focused
parser.add_argument("--pad_x", type=int, default=1, help="whether to always pad x to text_max_length. 1 gives the maximal memory consumption, but actual batches should be smaller, so 0 is preferable")
parser.add_argument("--audio_max_length", type=float, default=20, help="in seconds; crop or drop the audio if its length is longer than this")
parser.add_argument("--audio_min_length", type=float, default=2, help="in seconds; drop the audio if its length is shorter than this")
parser.add_argument("--text_max_length", type=int, default=400, help='if too long, we crop or drop')
parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop")
parser.add_argument("--encodec_sr", type=int, default=50, help="our encodec takes 16kHz audio with a downsample rate of 320, so the codec sample rate is 50Hz, i.e. 50 codes (x n_codebooks) per second")
parser.add_argument("--drop_long", type=int, default=0, help="if true, drop examples whose encodec sequence or phone sequence is too long, rather than cropping, to reduce hallucination")
# encodec and token rearrangement
parser.add_argument('--mask_len_min', type=int, default=1, help='Minimum mask length')
parser.add_argument('--mask_len_max', type=int, default=600, help='Maximum mask length')
parser.add_argument("--eos", type=int, default=-1, help="used with reduced_eog: end the utterance with eos and end the generated segment with eog; when this is used, n_special should be 4")
parser.add_argument("--reduced_eog", type=int, default=0, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts")
parser.add_argument("--special_first", type=int, default=0, help="if 1, need to have special tokens to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard coded to have 3 special tokens")
parser.add_argument("--n_special", type=int, default=3, help="empty, eog, pad, (eos)")
parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']")
parser.add_argument("--max_mask_portion",type=float,default=0.7,help="mask at most this portion of an utterance")
parser.add_argument("--max_n_spans", type=int, default=3, help='maximal number of spans, only used with multicm3; this decides the number of mask_embeddings, and the max clamp value when using a Poisson distribution; with a uniform distribution the number of spans is sampled from uniform(1, max_n_spans)')
parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether to shuffle the mask embeddings, so that mask:0 is not the best trained; default is no shuffling, which has its benefit: it makes sure mask:0 always appears first")
parser.add_argument("--mask_sample_dist", type=str, default="poisson1", help="uniform or poissonx, e.g. poisson1, meaning the Poisson parameter lambda is 1, so it will most likely sample 1 mask")
parser.add_argument("--min_gap", type=int, default=5, help="after sampling the starts, delete the later one if it is closer to the previous start than min_gap")
parser.add_argument('--n_codebooks', type=int, default=4)
parser.add_argument('--text_vocab_size', type=int, default=100, help='Size of text vocabulary')
parser.add_argument('--text_pad_token', type=int, default=100, help='padding of the text tokens, not attended')
parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary")
parser.add_argument("--empty_token", default=2048, type=int, help="indicating the no token at the position for the codebook")
parser.add_argument('--eog', type=int, default=2049, help='End of generation token')
parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended')
# model focused
parser.add_argument('--d_model', type=int, default=2048, help='Model dimension')
parser.add_argument('--audio_embedding_dim', type=int, default=2048, help='dimension of the continuous encodec embedding (before quantization)')
parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding')
parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding')
parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding')
parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding')
parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer')
parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads')
parser.add_argument('--num_decoder_layers', type=int, default=16, help='Number of decoder layers')
parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this will be effective last, so will overwrite all previous load, including resume')
return parser
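Taken together, the defaults above imply a fixed token-id layout and a maximum code length. A small sanity-check sketch (the values are the defaults from this parser, not new configuration):

```python
# Default special-token layout: with audio_vocab_size = 2048 and
# special_first = 0, the specials sit right after the codebook entries.
audio_vocab_size = 2048
empty_token = audio_vocab_size          # 2048, "no token at this position"
eog = audio_vocab_size + 1              # 2049, end of generation
audio_pad_token = audio_vocab_size + 2  # 2050, padding, not attended

# Maximum encodec sequence length: 20 s of audio at 50 codes per second
# (per codebook), matching --audio_max_length and --encodec_sr.
audio_max_length = 20.0
encodec_sr = 50
max_audio_frames = int(audio_max_length * encodec_sr)
```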

0
data/__init__.py Normal file
@ -0,0 +1,160 @@
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!")
parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files")
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes")
parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the total batch size summed over all GPUs, so increase it if you are using more GPUs")
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
parser.add_argument('--len_cap', type=float, default=35.0, help='drop audios that are longer than this many seconds')
return parser.parse_args()
if __name__ == "__main__":
import logging
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
import os
import numpy as np
import torch
import torchaudio
import tqdm
import time
args = parse_args()
manifest_dir = args.manifest_root # this dir is scp-ed
audio_dir = args.audio_dir # this is scp-ed flac dir
encodec_signature = args.encodec_model_path.split("/")[-2]
save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}")
os.makedirs(save_codes_dir, exist_ok=True)
# model_sr = 16000
# downsample_rate = 320
# model_code_sr = 50
def sort_by_audio_len(lens):
inds = np.argsort(lens).tolist()
logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
return inds[::-1]
def write_array_to_txt_file(array, filename):
with open(filename, 'w') as f:
for a in array[:-1]:
f.write(' '.join(map(str, a))+'\n')
f.write(' '.join(map(str, array[-1])))
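write_array_to_txt_file stores one codebook per line as space-separated integers with no trailing newline. A hedged round-trip sketch — read_codes_from_txt_file is our name, not in this repo, but it mirrors how the dataset loader later splits each line back into ints:

```python
import os
import tempfile

def write_array_to_txt_file(array, filename):
    # one codebook per line, space-separated, no trailing newline (as above)
    with open(filename, 'w') as f:
        for a in array[:-1]:
            f.write(' '.join(map(str, a)) + '\n')
        f.write(' '.join(map(str, array[-1])))

def read_codes_from_txt_file(filename):
    # inverse of the writer: split each line back into a list of ints
    with open(filename, 'r') as f:
        return [[int(n) for n in l.split()] for l in f.readlines()]

codes = [[1, 2, 3], [4, 5, 6]]  # 2 codebooks, 3 frames
with tempfile.TemporaryDirectory() as d:
    fn = os.path.join(d, "codes.txt")
    write_array_to_txt_file(codes, fn)
    roundtrip = read_codes_from_txt_file(fn)
```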
class mydataset(torch.utils.data.Dataset):
def __init__(self, split):
super().__init__()
# self.data = gs[split]
self.split = split
self.audio_root = audio_dir
manifest_fn = os.path.join(manifest_dir, split+".txt")
with open(manifest_fn, "r") as rf:
self.data = [l.strip().split("\t") for l in rf.readlines()]
def __len__(self):
return len(self.data)
def __getitem__(self, ind):
try:
afn = self.data[ind][0]
fn = os.path.join(self.audio_root, afn)
audio, sr = torchaudio.load(fn)
assert sr == args.model_sr, sr
except Exception as e:
logging.info(f"{e}")
return None, None, None
assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0]
def collate(self, batch):
lens, audios, segment_ids = [], [], []
for item in batch:
if item[0] is not None:
audios.append(item[0])
lens.append(item[1])
segment_ids.append(item[2])
return audios, lens, segment_ids
# load the encodec model
from audiocraft.solvers import CompressionSolver
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
model = model.cuda()
model = model.eval()
model = torch.nn.DataParallel(model)
# setup dataloader
mega_batch_size = 2100
batch_size = args.batch_size
train_dataset = mydataset('train')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
validation_dataset = mydataset('validation')
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
test_dataset = mydataset('test')
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
splits = ['validation', 'test', 'train']
loaders = [validation_loader, test_loader, train_loader]
# splits = ['validation'] # NOTE this is for debug, for example, see if the
# loaders = [validation_loader]
for split, loader in zip(splits, loaders):
skip = 0
logging.info(f"now processing split {split}...")
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
# mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples")
# with open(mani_fn, "a") as mani_wf: # resume from where we failed
for m, mega_batch in enumerate(loader):
logging.info(f"====================================")
logging.info(f"====================================")
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
lengths = np.array(mega_batch[1])
sorted_inds = sort_by_audio_len(lengths)
for j in range(len(sorted_inds))[::-1]:
if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s) or too long (longer than args.len_cap seconds)
skip += 1
del sorted_inds[j]
n_steps = int(np.ceil(len(sorted_inds) / batch_size))
for n in tqdm.tqdm(range(n_steps), disable=True):
inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
wav_batch = [mega_batch[0][id] for id in inds_used]
all_lens = [mega_batch[1][id] for id in inds_used]
segment_id_batch = [mega_batch[2][id] for id in inds_used]
# print(segment_id_batch)
padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
with torch.no_grad():
if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes
codes = []
inwav = padded_wav.cuda()
codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
codes = torch.cat(codes, dim=0)
else:
encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
codes = encoded_frames[0].cpu()
for i, length in enumerate(all_lens):
save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
write_array_to_txt_file(cur_code, save_fn)
# mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
# if i == 10:
# raise
# break
# logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
# break
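Within each mega-batch the samples are processed longest-first (sort_by_audio_len returns indices in descending length order), so every mini-batch groups similar lengths and wastes less padding. The ordering itself, as a minimal pure-Python sketch:

```python
# hypothetical audio lengths in samples
lens = [320, 16000, 4800]

# argsort ascending, then reverse, as sort_by_audio_len does with np.argsort
inds = sorted(range(len(lens)), key=lambda i: lens[i])[::-1]
# inds now lists the indices from longest to shortest audio
```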

158
data/gigaspeech.py Normal file

@ -0,0 +1,158 @@
import os
import torch
import random
import copy
import logging
import shutil
class dataset(torch.utils.data.Dataset):
def __init__(self, args, split):
super().__init__()
self.args = args
self.split = split
assert self.split in ['train', 'validation', 'test']
manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")
with open(manifest_fn, "r") as rf:
data = [l.strip().split("\t") for l in rf.readlines()]
lengths_list = [int(item[-1]) for item in data]
self.data = []
self.lengths_list = []
for d, l in zip(data, lengths_list):
if l >= self.args.encodec_sr*self.args.audio_min_length:
if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
continue
self.data.append(d)
self.lengths_list.append(l)
logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")
# phoneme vocabulary
vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt")
shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
with open(vocab_fn, "r") as f:
temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
self.phn2num = {item[1]:int(item[0]) for item in temp}
self.symbol_set = set(["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"])
def __len__(self):
return len(self.lengths_list)
def _load_phn_enc(self, index):
item = self.data[index]
pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
try:
with open(pf, "r") as p, open(ef, "r") as e:
phns = [l.strip() for l in p.readlines()]
assert len(phns) == 1, phns
x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
assert len(encos) == self.args.n_codebooks, ef
if self.args.special_first:
y = [[int(n)+self.args.n_special for n in l] for l in encos]
else:
y = [[int(n) for n in l] for l in encos]
if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
y = y[:1]
except Exception as e:
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
logging.info(f"error message: {e}")
return [], [[]]
return x, y
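When --special_first is set, ids 0..n_special-1 are reserved for the special tokens, so the raw encodec codes read from disk are shifted up by n_special, exactly as in the branch above. A standalone sketch:

```python
n_special = 3  # empty, eog, pad
special_first = True

# two codebooks as read from the codes txt file (strings, one line each)
encos = [["12", "0", "7"], ["3", "3", "3"]]

if special_first:
    # reserve ids 0..n_special-1 for the special tokens
    y = [[int(n) + n_special for n in l] for l in encos]
else:
    y = [[int(n) for n in l] for l in encos]
```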
def __getitem__(self, index):
x, y = self._load_phn_enc(index)
x_len, y_len = len(x), len(y[0])
if x_len == 0 or y_len == 0:
return {
"x": None,
"x_len": None,
"y": None,
"y_len": None,
"y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token
"extra_mask_start": None # this is only used in VE1
}
while y_len < self.args.encodec_sr*self.args.audio_min_length:
assert not self.args.dynamic_batching
index = random.choice(range(len(self))) # regenerate an index
x, y = self._load_phn_enc(index)
x_len, y_len = len(x), len(y[0])
if self.args.drop_long:
while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length:
index = random.choice(range(len(self))) # regenerate an index
x, y = self._load_phn_enc(index)
x_len, y_len = len(x), len(y[0])
### padding and cropping below ###
### padding and cropping below ###
# adjust the length of encodec codes, pad to max_len or randomly crop
orig_y_len = copy.copy(y_len)
max_len = int(self.args.audio_max_length * self.args.encodec_sr)
if y_len > max_len:
audio_start = random.choice(range(0, y_len-max_len))
for i in range(len(y)):
y[i] = y[i][audio_start:(audio_start+max_len)]
y_len = max_len
else:
audio_start = 0
if not self.args.dynamic_batching:
pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
for i in range(len(y)):
y[i] = y[i] + pad
# adjust text
# if audio is cropped, and text is longer than max, crop max based on how audio is cropped
if audio_start > 0 and len(x) > self.args.text_max_length: # if the audio was cropped and the text is longer than max, start the text proportionally to where the audio crop started
x = x[int(len(x)*audio_start/orig_y_len):]
if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end
x = x[:self.args.text_max_length]
x_len = len(x)
if x_len > self.args.text_max_length:
text_start = random.choice(range(0, x_len - self.args.text_max_length))
x = x[text_start:text_start+self.args.text_max_length]
x_len = self.args.text_max_length
elif self.args.pad_x and x_len <= self.args.text_max_length:
pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
x = x + pad
### padding and cropping above ###
### padding and cropping above ###
return {
"x": torch.LongTensor(x),
"x_len": x_len,
"y": torch.LongTensor(y),
"y_len": y_len
}
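The text-cropping rule above is easy to miss: if the audio was randomly cropped, the text is cropped starting at the proportional position, then truncated to text_max_length. A sketch with made-up numbers:

```python
text_max_length = 5
x = list(range(12))   # 12 phone ids, longer than text_max_length
orig_y_len = 1000     # encodec frames before cropping
audio_start = 500     # the random audio crop started halfway in

if audio_start > 0 and len(x) > text_max_length:
    # start the text proportionally to where the audio crop started
    x = x[int(len(x) * audio_start / orig_y_len):]
    if len(x) > text_max_length:
        # still too long: cut the end
        x = x[:text_max_length]
```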
def collate(self, batch):
out = {key:[] for key in batch[0]}
for item in batch:
if item['x'] is None: # deal with load failure
continue
for key, val in item.items():
out[key].append(val)
res = {}
if self.args.pad_x:
res["x"] = torch.stack(out["x"], dim=0)
else:
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
res["x_lens"] = torch.LongTensor(out["x_len"])
if self.args.dynamic_batching:
if out['y'][0].ndim==2:
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
else:
assert out['y'][0].ndim==1, out['y'][0].shape
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
else:
res['y'] = torch.stack(out['y'], dim=0)
res["y_lens"] = torch.LongTensor(out["y_len"])
res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
return res
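The two padding masks mark positions at or beyond each sequence's true length (True = padded, not attended). The torch.arange broadcast used above is equivalent to this pure-Python sketch:

```python
x_lens = [3, 5]  # true lengths of two text sequences in the batch
T = 5            # padded sequence length

# mask[b][t] is True where t >= x_lens[b], mirroring
# torch.arange(T).unsqueeze(0) >= x_lens.unsqueeze(1)
text_padding_mask = [[t >= n for t in range(T)] for n in x_lens]
```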

149
data/tokenizer.py Normal file

@ -0,0 +1,149 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union
import numpy as np
import torch
import torchaudio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone]
+ [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
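to_list splits the phonemized string on the word separator, tokenizes each word with \w+|[^\w\s], and drops the phone separator while keeping word boundaries. A standalone sketch of that splitting, using a made-up phonemized string (separators as configured above: word "_", phone "|"):

```python
import re

# hypothetical espeak-style output for "hello world", already stripped
phonemized = "h|e|l|ow_w|er|l|d"
word_sep, phone_sep = "_", "|"

fields = []
for word in phonemized.split(word_sep):
    # split each word into phones and punctuation
    pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
    # drop the phone separators, keep a word-boundary marker
    fields.extend([p for p in pp if p != phone_sep] + [word_sep])
tokens = fields[:-1]  # drop the trailing word separator
```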
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
if target_channels == 1:
wav = wav.mean(0, keepdim=True)
elif target_channels == 2:
*shape, _, length = wav.shape
wav = wav.expand(*shape, target_channels, length)
elif wav.shape[0] == 1:
wav = wav.expand(target_channels, -1)
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
return wav
class AudioTokenizer:
"""EnCodec audio."""
def __init__(
self,
device: Any = None,
signature = None
) -> None:
from audiocraft.solvers import CompressionSolver
model = CompressionSolver.model_from_checkpoint(signature)
self.sample_rate = model.sample_rate
self.channels = model.channels
if not device:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
self._device = device
self.codec = model.to(device)
@property
def device(self):
return self._device
def encode(self, wav: torch.Tensor) -> torch.Tensor:
codes = self.codec.encode(wav.to(self.device))
return [(codes[0], None)]
def decode(self, frames: torch.Tensor) -> torch.Tensor:
frames = frames[0][0] # [1,4,T]
return self.codec.decode(frames)
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
# Load and pre-process the audio waveform
if offset != -1 and num_frames!=-1:
wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
else:
wav, sr = torchaudio.load(audio_path)
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
wav = wav.unsqueeze(0)
# Extract discrete codes from EnCodec
with torch.no_grad():
encoded_frames = tokenizer.encode(wav)
return encoded_frames

Binary file not shown.


@ -0,0 +1 @@
But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,


@ -0,0 +1,109 @@
Begin,End,Label,Type,Speaker
0.03,0.18,but,words,temp
0.18,0.32,when,words,temp
0.32,0.49,i,words,temp
0.49,0.64,had,words,temp
0.64,1.19,approached,words,temp
1.22,1.58,so,words,temp
1.58,1.9,near,words,temp
1.9,2.07,to,words,temp
2.07,2.42,them,words,temp
2.53,2.61,the,words,temp
2.61,3.01,common,words,temp
3.05,3.62,object,words,temp
3.68,3.93,which,words,temp
3.93,4.02,the,words,temp
4.02,4.34,sense,words,temp
4.34,4.97,deceives,words,temp
5.04,5.54,lost,words,temp
5.54,6.0,not,words,temp
6.0,6.14,by,words,temp
6.14,6.67,distance,words,temp
6.79,7.06,any,words,temp
7.06,7.18,of,words,temp
7.18,7.34,its,words,temp
7.34,7.87,marks,words,temp
0.03,0.06,B,phones,temp
0.06,0.09,AH1,phones,temp
0.09,0.18,T,phones,temp
0.18,0.23,W,phones,temp
0.23,0.27,EH1,phones,temp
0.27,0.32,N,phones,temp
0.32,0.49,AY1,phones,temp
0.49,0.5,HH,phones,temp
0.5,0.6,AE1,phones,temp
0.6,0.64,D,phones,temp
0.64,0.7,AH0,phones,temp
0.7,0.83,P,phones,temp
0.83,0.87,R,phones,temp
0.87,0.99,OW1,phones,temp
0.99,1.12,CH,phones,temp
1.12,1.19,T,phones,temp
1.22,1.4,S,phones,temp
1.4,1.58,OW1,phones,temp
1.58,1.7,N,phones,temp
1.7,1.84,IH1,phones,temp
1.84,1.9,R,phones,temp
1.9,2.01,T,phones,temp
2.01,2.07,AH0,phones,temp
2.07,2.13,DH,phones,temp
2.13,2.3,EH1,phones,temp
2.3,2.42,M,phones,temp
2.53,2.55,DH,phones,temp
2.55,2.61,AH0,phones,temp
2.61,2.73,K,phones,temp
2.73,2.85,AA1,phones,temp
2.85,2.9,M,phones,temp
2.9,2.95,AH0,phones,temp
2.95,3.01,N,phones,temp
3.05,3.22,AA1,phones,temp
3.22,3.27,B,phones,temp
3.27,3.34,JH,phones,temp
3.34,3.48,EH0,phones,temp
3.48,3.54,K,phones,temp
3.54,3.62,T,phones,temp
3.68,3.69,HH,phones,temp
3.69,3.76,W,phones,temp
3.76,3.8,IH1,phones,temp
3.8,3.93,CH,phones,temp
3.93,3.95,DH,phones,temp
3.95,4.02,AH0,phones,temp
4.02,4.12,S,phones,temp
4.12,4.21,EH1,phones,temp
4.21,4.27,N,phones,temp
4.27,4.34,S,phones,temp
4.34,4.42,D,phones,temp
4.42,4.45,IH0,phones,temp
4.45,4.59,S,phones,temp
4.59,4.8,IY1,phones,temp
4.8,4.87,V,phones,temp
4.87,4.97,Z,phones,temp
5.04,5.12,L,phones,temp
5.12,5.33,AO1,phones,temp
5.33,5.42,S,phones,temp
5.42,5.54,T,phones,temp
5.54,5.7,N,phones,temp
5.7,5.89,AA1,phones,temp
5.89,6.0,T,phones,temp
6.0,6.05,B,phones,temp
6.05,6.14,AY1,phones,temp
6.14,6.24,D,phones,temp
6.24,6.3,IH1,phones,temp
6.3,6.38,S,phones,temp
6.38,6.45,T,phones,temp
6.45,6.51,AH0,phones,temp
6.51,6.57,N,phones,temp
6.57,6.67,S,phones,temp
6.79,6.89,EH1,phones,temp
6.89,6.95,N,phones,temp
6.95,7.06,IY0,phones,temp
7.06,7.13,AH0,phones,temp
7.13,7.18,V,phones,temp
7.18,7.22,IH0,phones,temp
7.22,7.29,T,phones,temp
7.29,7.34,S,phones,temp
7.34,7.39,M,phones,temp
7.39,7.49,AA1,phones,temp
7.49,7.58,R,phones,temp
7.58,7.69,K,phones,temp
7.69,7.87,S,phones,temp

49
edit_utils.py Normal file

@ -0,0 +1,49 @@
def get_span(orig, new, editType):
orig_list = orig.split(" ")
new_list = new.split(" ")
flag = False # this indicates whether the actual edit follows the specified editType
if editType == "deletion":
assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}"
diff = len(orig_list) - len(new_list)
for i, (o, n) in enumerate(zip(orig_list, new_list)):
if o != n: # assume the index of the first different word is the starting index of the orig_span
orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part
new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part
flag = True
break
elif editType == "insertion":
assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}"
diff = len(new_list) - len(orig_list)
for i, (o, n) in enumerate(zip(orig_list, new_list)):
if o != n: # insertion is just the opposite of deletion
new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same
orig_span = [i-1, i]
flag = True
break
elif editType == "substitution":
new_span = []
orig_span = []
for i, (o, n) in enumerate(zip(orig_list, new_list)):
if o != n:
new_span = [i]
orig_span = [i]
break
assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}"
for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])):
if o != n:
new_span.append(len(new_list) - j -1)
orig_span.append(len(orig_list) - j - 1)
flag = True
break
else:
raise RuntimeError(f"editType unknown: {editType}")
if not flag:
raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}")
return orig_span, new_span
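For substitutions, get_span scans for the first differing word from the left and from the right; everything in between is the edited region. A simplified standalone sketch of that idea (sub_span is our name, not in this repo, and it skips get_span's assertions):

```python
def sub_span(orig, new):
    # first differing word from the left, then from the right
    o, n = orig.split(" "), new.split(" ")
    start = next(i for i, (a, b) in enumerate(zip(o, n)) if a != b)
    j = next(i for i, (a, b) in enumerate(zip(o[::-1], n[::-1])) if a != b)
    return [start, len(o) - j - 1], [start, len(n) - j - 1]

orig_span, new_span = sub_span("the quick brown fox", "the very slow brown fox")
# word 1 of the original ("quick") maps to words 1-2 of the new ("very slow")
```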


@ -0,0 +1,209 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pyp/miniconda3/envs/voicecraft/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"# import libs\n",
"import torch\n",
"import torchaudio\n",
"\n",
"from data.tokenizer import (\n",
" AudioTokenizer,\n",
" TextTokenizer,\n",
")\n",
"\n",
"from models import voicecraft\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# hyperparameters for inference\n",
"left_margin = 0.08\n",
"right_margin = 0.08\n",
"seed = 1\n",
"codec_audio_sr = 16000\n",
"codec_sr = 50\n",
"top_k = 0\n",
"top_p = 0.8\n",
"temperature = 1\n",
"kvcache = 0\n",
"silence_tokens = [1388,1898,131]\n",
"stop_repetition = -1 # do not stop repetition on silence\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# point to the original file or record the file\n",
"# write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file\n",
"orig_audio = \"./demo/84_121550_000074_000000.wav\"\n",
"orig_transcript = \"But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,\"\n",
"# move the audio and transcript to temp folder\n",
"temp_folder = \"./demo/temp\"\n",
"os.makedirs(temp_folder, exist_ok=True)\n",
"os.system(f\"cp {orig_audio} {temp_folder}\")\n",
"filename = os.path.splitext(orig_audio.split(\"/\")[-1])[0]\n",
"with open(f\"{temp_folder}/{filename}.txt\", \"w\") as f:\n",
" f.write(orig_transcript)\n",
"# run MFA to get the alignment\n",
"align_temp = f\"{temp_folder}/mfa_alignments\"\n",
"os.makedirs(align_temp, exist_ok=True)\n",
"os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}\")\n",
"# if the above fails, it could be because the audio is too hard for the alignment model; increasing the beam size usually solves the issue\n",
"# os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000\")\n",
"audio_fn = f\"{temp_folder}/{filename}.wav\"\n",
"transcript_fn = f\"{temp_folder}/{filename}.txt\"\n",
"align_fn = f\"{align_temp}/{filename}.csv\"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1)\n"
]
}
],
"source": [
"editTypes_set = set(['substitution', 'insertion', 'deletion'])\n",
"# specify what you want the modified (target) transcript to be\n",
"target_transcript = \"But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,\"\n",
"edit_type = \"substitution\"\n",
"assert edit_type in editTypes_set, f\"Invalid edit type {edit_type}. Must be one of {editTypes_set}.\"\n",
"\n",
"# if you want to do a second modification on top of the first one, write down the second modification (target_transcript2, type_of_modification2)\n",
"# make sure the two modifications do not overlap; if they do, combine them into one modification\n",
"\n",
"# run the script to turn user input to the format that the model can take\n",
"from edit_utils import get_span\n",
"orig_span, new_span = get_span(orig_transcript, target_transcript, edit_type)\n",
"if orig_span[0] > orig_span[1]:\n",
"    raise RuntimeError(f\"example {audio_fn} failed\")\n",
"if orig_span[0] == orig_span[1]:\n",
" orig_span_save = [orig_span[0]]\n",
"else:\n",
" orig_span_save = orig_span\n",
"if new_span[0] == new_span[1]:\n",
" new_span_save = [new_span[0]]\n",
"else:\n",
" new_span_save = new_span\n",
"\n",
"orig_span_save = \",\".join([str(item) for item in orig_span_save])\n",
"new_span_save = \",\".join([str(item) for item in new_span_save])\n",
"from inference_speech_editing_scale import get_mask_interval\n",
"\n",
"start, end = get_mask_interval(align_fn, orig_span_save, edit_type)\n",
"info = torchaudio.info(audio_fn)\n",
"audio_dur = info.num_frames / info.sample_rate\n",
"morphed_span = (max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur)) # in seconds\n",
"\n",
"# span in codec frames\n",
"mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]\n",
"mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now\n",
"\n",
"# load model, tokenizer, and other necessary files\n",
"ckpt_fn = \"/data/scratch/pyp/exp_pyp/VoiceCraft/gigaspeech/pretrained_830M/best_bundle.pth\"\n",
"encodec_fn = \"/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th\"\n",
"ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n",
"model = voicecraft.VoiceCraft(ckpt[\"config\"])\n",
"model.load_state_dict(ckpt[\"model\"])\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"phn2num = ckpt['phn2num']\n",
"\n",
"text_tokenizer = TextTokenizer(backend=\"espeak\")\n",
"audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu\n",
"\n",
"# run the model to get the output\n",
"from inference_speech_editing_scale import inference_one_sample\n",
"\n",
"decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens}\n",
"orig_audio, new_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, mask_interval, device, decode_config)\n",
" \n",
"# save segments for comparison\n",
"orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()\n",
"# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n",
"\n",
"# output_dir\n",
"output_dir = \"./demo/generated_se\"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"save_fn_new = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{seed}.wav\"\n",
"\n",
"torchaudio.save(save_fn_new, new_audio, codec_audio_sr)\n",
"\n",
"save_fn_orig = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav\"\n",
"if not os.path.isfile(save_fn_orig):\n",
" orig_audio, orig_sr = torchaudio.load(audio_fn)\n",
" if orig_sr != codec_audio_sr:\n",
" orig_audio = torchaudio.transforms.Resample(orig_sr, codec_audio_sr)(orig_audio)\n",
" torchaudio.save(save_fn_orig, orig_audio, codec_audio_sr)\n",
"\n",
"# if you get an error importing T5 from transformers, try\n",
"# pip uninstall Pillow\n",
"# pip install Pillow\n",
"# you are likely to get a warning that looks like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1); it can be safely ignored"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "voicecraft",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
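The decoding hyperparameters above (`top_k = 0`, `top_p = 0.8`, `temperature = 1`) correspond to nucleus sampling. A minimal, dependency-free sketch of top-p filtering over a toy distribution (the logits are made up; the real sampling happens inside `model.inference`):

```python
import math

def top_p_filter(logits, top_p=0.8):
    # softmax over the toy logits
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # keep the smallest set of highest-probability tokens whose mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = set(), 0.0
    for i in order:
        if cum >= top_p:
            break
        kept.add(i)
        cum += probs[i]
    # renormalize the surviving tokens; everything else gets probability 0
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]

filtered = top_p_filter([2.0, 1.0, 0.5, -1.0], top_p=0.8)
# only the top two tokens survive, and their probabilities sum to 1
```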


@@ -0,0 +1,226 @@
import argparse, pickle
import logging
import os, random
import numpy as np
import torch
import torchaudio
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
tokenize_audio,
tokenize_text
)
from models import voicecraft
import argparse, time, tqdm
# this script only works for the musicgen architecture
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
parser.add_argument("--left_margin", type=float, default=0.08, help="extra space on the left to the word boundary")
parser.add_argument("--right_margin", type=float, default=0.08, help="extra space on the right to the word boundary")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
parser.add_argument("--top_k", type=int, default=-1, help="sampling param")
parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
parser.add_argument("--stop_repetition", type=int, default=2, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
    parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specify this yourself rather than using the default")
    parser.add_argument("--crop_concat", type=int, default=0)
    return parser.parse_args()
@torch.no_grad()
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, device, decode_config):
# phonemize
text_tokens = [phn2num[phn] for phn in
tokenize_text(
text_tokenizer, text=target_text.strip()
) if phn in phn2num
]
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
encoded_frames = tokenize_audio(audio_tokenizer, audio_fn)
original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
# forward
stime = time.time()
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
mask_interval=mask_interval.unsqueeze(0).to(device),
top_k=decode_config['top_k'],
top_p=decode_config['top_p'],
temperature=decode_config['temperature'],
stop_repetition=decode_config['stop_repetition'],
kvcache=decode_config['kvcache'],
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens'],
) # output is [1,K,T]
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
if type(encoded_frames) == tuple:
encoded_frames = encoded_frames[0]
logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.")
# decode (both original and generated)
original_sample = audio_tokenizer.decode(
[(original_audio.transpose(2,1), None)] # [1,T,8] -> [1,8,T]
)
generated_sample = audio_tokenizer.decode(
[(encoded_frames, None)]
)
return original_sample, generated_sample
def get_model(exp_dir, device=None):
with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
model_args = pickle.load(f)
logging.info("load model weights...")
model = voicecraft.VoiceCraft(model_args)
ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
model.load_state_dict(ckpt)
del ckpt
logging.info("done loading weights...")
    if device is None:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
model.to(device)
model.eval()
return model, model_args, phn2num
def get_mask_interval(ali_fn, word_span_ind, editType):
with open(ali_fn, "r") as rf:
data = [l.strip().split(",") for l in rf.readlines()]
data = data[1:]
tmp = word_span_ind.split(",")
s, e = int(tmp[0]), int(tmp[-1])
start = None
for j, item in enumerate(data):
if j == s and item[3] == "words":
if editType == 'insertion':
start = float(item[1])
else:
start = float(item[0])
if j == e and item[3] == "words":
if editType == 'insertion':
end = float(item[0])
else:
end = float(item[1])
            assert start is not None
break
return (start, end)
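For reference, `get_mask_interval` above walks an MFA alignment CSV whose rows are `Begin,End,Label,Type,Speaker`. A self-contained sketch of the same lookup on an in-memory CSV (the timings and words here are made up, not real aligner output):

```python
import csv, io

# hypothetical MFA output; real files come from `mfa align --output_format csv`
CSV_TEXT = """Begin,End,Label,Type,Speaker
0.00,0.41,but,words,spk1
0.41,0.93,when,words,spk1
0.93,1.40,i,words,spk1
"""

def mask_interval(csv_text, span, edit_type):
    rows = list(csv.reader(io.StringIO(csv_text)))[1:]  # drop header
    s, e = span
    start = end = None
    for j, (begin, fin, _label, typ, _spk) in enumerate(rows):
        if j == s and typ == "words":
            # an insertion starts at the word's *end* boundary
            start = float(fin) if edit_type == "insertion" else float(begin)
        if j == e and typ == "words":
            end = float(begin) if edit_type == "insertion" else float(fin)
    return start, end

print(mask_interval(CSV_TEXT, (1, 1), "substitution"))  # (0.41, 0.93)
```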
if __name__ == "__main__":
def seed_everything(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
# args.device = 'cpu'
seed_everything(args.seed)
# load model
stime = time.time()
logging.info(f"loading model from {args.exp_dir}")
model, model_args, phn2num = get_model(args.exp_dir)
    if not os.path.isdir(model_args.exp_dir):
model_args.exp_dir = args.exp_dir
logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
# setup text and audio tokenizer
text_tokenizer = TextTokenizer(backend="espeak")
audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
with open(args.manifest_fn, "r") as rf:
manifest = [l.strip().split("\t") for l in rf.readlines()]
manifest = manifest[1:]
    # wav_fn txt_fn alignment_fn num_words word_span_ind
audio_fns = []
target_texts = []
mask_intervals = []
edit_types = []
new_spans = []
orig_spans = []
os.makedirs(args.output_dir, exist_ok=True)
if args.crop_concat:
mfa_temp = f"{args.output_dir}/mfa_temp"
os.makedirs(mfa_temp, exist_ok=True)
for item in manifest:
audio_fn = os.path.join(args.audio_root, item[0])
temp = torchaudio.info(audio_fn)
audio_dur = temp.num_frames/temp.sample_rate
audio_fns.append(audio_fn)
target_text = item[2].split("|")[-1]
edit_types.append(item[5].split("|"))
new_spans.append(item[4].split("|"))
orig_spans.append(item[3].split("|"))
target_texts.append(target_text) # the last transcript is the target
# mi needs to be created from word_ind_span and alignment_fn, along with args.left_margin and args.right_margin
mis = []
all_ind_intervals = item[3].split("|")
editTypes = item[5].split("|")
smaller_indx = []
alignment_fn = os.path.join(args.audio_root, "aligned", item[0].replace(".wav", ".csv"))
if not os.path.isfile(alignment_fn):
alignment_fn = alignment_fn.replace("/aligned/", "/aligned_csv/")
assert os.path.isfile(alignment_fn), alignment_fn
for ind_inter,editType in zip(all_ind_intervals, editTypes):
# print(ind_inter)
mi = get_mask_interval(alignment_fn, ind_inter, editType)
mi = (max(mi[0] - args.left_margin, 1/args.codec_sr), min(mi[1] + args.right_margin, audio_dur)) # in seconds
mis.append(mi)
smaller_indx.append(mi[0])
ind = np.argsort(smaller_indx)
        mis = [mis[idx] for idx in ind]
mask_intervals.append(mis)
for i, (audio_fn, target_text, mask_interval) in enumerate(tqdm.tqdm(zip(audio_fns, target_texts, mask_intervals))):
orig_mask_interval = mask_interval
mask_interval = [[round(cmi[0]*args.codec_sr), round(cmi[1]*args.codec_sr)] for cmi in mask_interval]
# logging.info(f"i: {i}, mask_interval: {mask_interval}")
mask_interval = torch.LongTensor(mask_interval) # [M,2]
orig_audio, new_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, args.device, vars(args))
# save segments for comparison
orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()
# logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")
save_fn_new = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{args.seed}.wav"
torchaudio.save(save_fn_new, new_audio, args.codec_audio_sr)
save_fn_orig = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav"
if not os.path.isfile(save_fn_orig):
orig_audio, orig_sr = torchaudio.load(audio_fn)
if orig_sr != args.codec_audio_sr:
orig_audio = torchaudio.transforms.Resample(orig_sr, args.codec_audio_sr)(orig_audio)
torchaudio.save(save_fn_orig, orig_audio, args.codec_audio_sr)
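The margin-and-rounding step in the loop above converts a word-boundary interval in seconds into codec-frame indices. A small worked sketch with the default margins (the interval itself is hypothetical):

```python
codec_sr = 50                      # codec frames per second (default above)
left_margin = right_margin = 0.08  # extra context around the word boundary
audio_dur = 5.0                    # hypothetical clip length in seconds

start, end = 1.20, 2.35            # word boundaries from the aligner, in seconds
span = (max(start - left_margin, 1 / codec_sr),
        min(end + right_margin, audio_dur))
mask = [round(span[0] * codec_sr), round(span[1] * codec_sr)]
print(mask)  # [56, 122]: codec frames 56..122 get regenerated
```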

inference_tts.ipynb Normal file

@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pyp/miniconda3/envs/voicecraft/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"# import libs\n",
"import torch\n",
"import torchaudio\n",
"\n",
"from data.tokenizer import (\n",
" AudioTokenizer,\n",
" TextTokenizer,\n",
")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# hyperparameters for inference\n",
"left_margin = 0.08\n",
"right_margin = 0.08\n",
"seed = 1\n",
"codec_audio_sr = 16000\n",
"codec_sr = 50\n",
"top_k = 0\n",
"top_p = 0.8\n",
"temperature = 1\n",
"kvcache = 0\n",
"silence_tokens = [1388,1898,131]\n",
"# if there are long silences in the generated audio, reduce stop_repetition to 3, 2 or even 1\n",
"stop_repetition = 2\n",
"# if there are long silences or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4\n",
"# this makes the model generate sample_batch_size samples of the same audio and keep the shortest one\n",
"sample_batch_size = 1\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# point to the original audio file, or record one\n",
"# write down the transcript for the file, or run whisper to get one (you can correct it if it's not accurate); it is saved as a .txt file below\n",
"orig_audio = \"/home/pyp/VoiceCraft/demo/84_121550_000074_000000.wav\"\n",
"orig_transcript = \"But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,\"\n",
"\n",
"# move the audio and transcript to temp folder\n",
"temp_folder = \"/home/pyp/VoiceCraft/demo/temp\"\n",
"os.makedirs(temp_folder, exist_ok=True)\n",
"os.system(f\"cp {orig_audio} {temp_folder}\")\n",
"filename = os.path.splitext(orig_audio.split(\"/\")[-1])[0]\n",
"with open(f\"{temp_folder}/{filename}.txt\", \"w\") as f:\n",
" f.write(orig_transcript)\n",
"# run MFA to get the alignment\n",
"align_temp = f\"{temp_folder}/mfa_alignments\"\n",
"os.makedirs(align_temp, exist_ok=True)\n",
"os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp}\")\n",
"# if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue\n",
"# os.system(f\"mfa align -j 1 --output_format csv {temp_folder} english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000\")\n",
"audio_fn = f\"{temp_folder}/{filename}.wav\"\n",
"transcript_fn = f\"{temp_folder}/{filename}.txt\"\n",
"align_fn = f\"{align_temp}/{filename}.csv\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Dora directory: /tmp/audiocraft_pyp\n"
]
}
],
"source": [
"# take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt\n",
"cut_off_sec = 3.01 # according to the forced-alignment file, the word \"common\" stops at 3.01 sec\n",
"target_transcript = \"But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well!\"\n",
"info = torchaudio.info(audio_fn)\n",
"audio_dur = info.num_frames / info.sample_rate\n",
"\n",
"assert cut_off_sec < audio_dur, f\"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}\"\n",
"prompt_end_frame = int(cut_off_sec * info.sample_rate)\n",
"\n",
"\n",
"# load model, tokenizer, and other necessary files\n",
"from models import voicecraft\n",
"ckpt_fn = \"/data/scratch/pyp/exp_pyp/VoiceCraft/gigaspeech/pretrained_830M/best_bundle.pth\"\n",
"encodec_fn = \"/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th\"\n",
"ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n",
"model = voicecraft.VoiceCraft(ckpt[\"config\"])\n",
"model.load_state_dict(ckpt[\"model\"])\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"phn2num = ckpt['phn2num']\n",
"\n",
"text_tokenizer = TextTokenizer(backend=\"espeak\")\n",
"audio_tokenizer = AudioTokenizer(signature=encodec_fn) # will also put the neural codec model on gpu\n",
"\n",
"# run the model to get the output\n",
"decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n",
"from inference_tts_scale import inference_one_sample\n",
"concated_audio, gen_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n",
" \n",
"# save segments for comparison\n",
"concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n",
"# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n",
"\n",
"# output_dir\n",
"output_dir = \"/home/pyp/VoiceCraft/demo/generated_tts\"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"seg_save_fn_gen = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav\"\n",
"seg_save_fn_concat = f\"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav\" \n",
"\n",
"torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)\n",
"torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)\n",
"\n",
"# if you get an error importing T5 from transformers, try\n",
"# pip uninstall Pillow\n",
"# pip install Pillow\n",
"# you might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1); they can be safely ignored"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "voicecraft",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
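The prompt cut-off in the TTS notebook above is chosen at a word boundary from the alignment and converted into a waveform sample index. A tiny sketch with assumed values (in the notebook, `sample_rate` comes from `torchaudio.info` and `audio_dur` from the file itself):

```python
sample_rate = 16000   # assumed; read from torchaudio.info(audio_fn) in the notebook
cut_off_sec = 3.01    # word boundary picked from the MFA alignment CSV
audio_dur = 6.5       # hypothetical clip duration in seconds

assert cut_off_sec < audio_dur, "the prompt must end before the audio does"
prompt_end_frame = int(cut_off_sec * sample_rate)
print(prompt_end_frame)  # 48160 samples of the original audio serve as the prompt
```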

inference_tts_scale.py Normal file

@@ -0,0 +1,190 @@
import argparse, pickle
import logging
import os, random
import numpy as np
import torch
import torchaudio
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
tokenize_audio,
tokenize_text
)
from models import voicecraft
import argparse, time, tqdm
# this script only works for the musicgen architecture
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
parser.add_argument("--top_k", type=int, default=0, help="sampling param")
parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
parser.add_argument("--crop_concat", type=int, default=0)
parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
    parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling; NOTE that it does not run inference on several different samples, but duplicates one input sample batch_size times, and only the shortest generation is returned")
    parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specify this yourself rather than using the default")
return parser.parse_args()
@torch.no_grad()
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame):
# phonemize
text_tokens = [phn2num[phn] for phn in
tokenize_text(
text_tokenizer, text=target_text.strip()
) if phn in phn2num
]
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
# encode audio
encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
# forward
stime = time.time()
if decode_config['sample_batch_size'] <= 1:
logging.info(f"running inference with batch size 1")
concat_frames, gen_frames = model.inference_tts(
text_tokens.to(device),
text_tokens_lens.to(device),
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
top_k=decode_config['top_k'],
top_p=decode_config['top_p'],
temperature=decode_config['temperature'],
stop_repetition=decode_config['stop_repetition'],
kvcache=decode_config['kvcache'],
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
) # output is [1,K,T]
else:
logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.")
concat_frames, gen_frames = model.inference_tts_batch(
text_tokens.to(device),
text_tokens_lens.to(device),
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
top_k=decode_config['top_k'],
top_p=decode_config['top_p'],
temperature=decode_config['temperature'],
stop_repetition=decode_config['stop_repetition'],
kvcache=decode_config['kvcache'],
batch_size = decode_config['sample_batch_size'],
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
) # output is [1,K,T]
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")
# for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)):
# logging.info(f"{timestamp}: {codes.tolist()}")
# decode (both original and generated)
concat_sample = audio_tokenizer.decode(
[(concat_frames, None)] # [1,T,8] -> [1,8,T]
)
gen_sample = audio_tokenizer.decode(
[(gen_frames, None)]
)
# return
return concat_sample, gen_sample
def get_model(exp_dir, device=None):
with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
model_args = pickle.load(f)
logging.info("load model weights...")
model = voicecraft.VoiceCraft(model_args)
ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
model.load_state_dict(ckpt)
del ckpt
logging.info("done loading weights...")
    if device is None:
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda:0")
model.to(device)
model.eval()
return model, model_args, phn2num
if __name__ == "__main__":
def seed_everything(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
# args.device='cpu'
seed_everything(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
# load model
with open(args.manifest_fn, "r") as rf:
manifest = [l.strip().split("\t") for l in rf.readlines()]
manifest = manifest[1:]
manifest = [[item[0], item[2], item[3], item[1], item[5]] for item in manifest]
stime = time.time()
logging.info(f"loading model from {args.exp_dir}")
model, model_args, phn2num = get_model(args.exp_dir)
logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
# setup text and audio tokenizer
text_tokenizer = TextTokenizer(backend="espeak")
audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
audio_fns = []
texts = []
prompt_end_frames = []
new_audio_fns = []
text_to_syn = []
for item in manifest:
audio_fn = os.path.join(args.audio_root, item[0])
audio_fns.append(audio_fn)
temp = torchaudio.info(audio_fn)
prompt_end_frames.append(round(float(item[2])*temp.sample_rate))
texts.append(item[1])
new_audio_fns.append(item[-2])
all_text = item[1].split(" ")
start_ind = int(item[-1].split(",")[0])
text_to_syn.append(" ".join(all_text[start_ind:]))
for i, (audio_fn, text, prompt_end_frame, new_audio_fn, to_syn) in enumerate(tqdm.tqdm((zip(audio_fns, texts, prompt_end_frames, new_audio_fns, text_to_syn)))):
output_expected_sr = args.codec_audio_sr
concated_audio, gen_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, text, args.device, vars(args), prompt_end_frame)
# save segments for comparison
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
if output_expected_sr != args.codec_audio_sr:
gen_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(gen_audio)
concated_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(concated_audio)
seg_save_fn_gen = f"{args.output_dir}/gen_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr)
torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr)

main.py Normal file

@@ -0,0 +1,45 @@
from pathlib import Path
import torch
import pickle
import argparse
import logging
import torch.distributed as dist
from config import MyParser
from steps import trainer
if __name__ == "__main__":
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
torch.cuda.empty_cache()
args = MyParser().parse_args()
logging.info(args)
exp_dir = Path(args.exp_dir)
exp_dir.mkdir(exist_ok=True, parents=True)
logging.info(f"exp_dir: {str(exp_dir)}")
if args.resume:
resume = args.resume
        assert bool(args.exp_dir)
with open("%s/args.pkl" % args.exp_dir, "rb") as f:
old_args = pickle.load(f)
new_args = vars(args)
old_args = vars(old_args)
for key in new_args:
if key not in old_args or old_args[key] != new_args[key]:
old_args[key] = new_args[key]
args = argparse.Namespace(**old_args)
args.resume = resume
else:
with open("%s/args.pkl" % args.exp_dir, "wb") as f:
pickle.dump(args, f)
dist.init_process_group(backend='nccl', init_method='env://')
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
my_trainer = trainer.Trainer(args, world_size, rank)
my_trainer.train()
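The resume branch in `main.py` above rebuilds the run's `Namespace` by overlaying new command-line values on the pickled ones. A toy sketch of that merge (the field names and values are made up):

```python
import argparse

# hypothetical old (pickled) args and new command-line args
old = vars(argparse.Namespace(lr=1e-4, exp_dir="exp/run1", resume=False))
new = vars(argparse.Namespace(lr=5e-5, exp_dir="exp/run1", resume=True))

# same overlay as main.py: any new or changed key wins over the pickled value
for key in new:
    if key not in old or old[key] != new[key]:
        old[key] = new[key]
args = argparse.Namespace(**old)
print(args.lr, args.resume)  # 5e-05 True
```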


@@ -0,0 +1,538 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp
from abc import ABC, abstractmethod
import torch
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
@dataclass
class Pattern:
"""Base implementation of a pattern over a sequence with multiple codebooks.
The codebook pattern consists in a layout, defining for each sequence step
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
The first item of the pattern is always an empty list in order to properly insert a special token
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
and ``timesteps`` the number of timesteps corresponding to the original sequence.
The pattern provides convenient methods to build and revert interleaved sequences from it:
``build_pattern_sequence`` maps a dense input tensor of a multi-codebook sequence of shape [B, K, T]
to the interleaved sequence of shape [B, K, S] by applying the pattern, with B being the batch size,
K the number of codebooks, T the number of original timesteps and S the number of sequence steps
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
is returned along with a mask indicating valid tokens.
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
to fill and specify invalid positions if needed.
See the dedicated methods for more details.
"""
# Pattern layout, for each sequence step, we have a list of coordinates
# corresponding to the original codebook timestep and position.
# The first list is always an empty list in order to properly insert
# a special token to start with.
layout: PatternLayout
timesteps: int
n_q: int
def __post_init__(self):
assert len(self.layout) > 0
assert self.layout[0] == []
self._validate_layout()
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
# logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
def _validate_layout(self):
"""Runs checks on the layout to ensure a valid pattern is defined.
A pattern is considered invalid if:
- Multiple timesteps for the same codebook are defined in the same sequence step
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
(this would mean that we have future timesteps before past timesteps).
"""
q_timesteps = {q: 0 for q in range(self.n_q)}
for s, seq_coords in enumerate(self.layout):
if len(seq_coords) > 0:
qs = set()
for coord in seq_coords:
qs.add(coord.q)
last_q_timestep = q_timesteps[coord.q]
assert coord.t >= last_q_timestep, \
f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
q_timesteps[coord.q] = coord.t
# each sequence step contains at max 1 coordinate per codebook
assert len(qs) == len(seq_coords), \
f"Multiple entries for a same codebook are found at step {s}"
@property
def num_sequence_steps(self):
return len(self.layout) - 1
@property
def max_delay(self):
max_t_in_seq_coords = 0
for seq_coords in self.layout[1:]:
for coords in seq_coords:
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
return max_t_in_seq_coords - self.timesteps
@property
def valid_layout(self):
valid_step = len(self.layout) - self.max_delay
return self.layout[:valid_step]
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
and the actual codebook coordinates.
"""
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
if q is not None:
assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
coords = []
for s, seq_codes in enumerate(self.layout):
for code in seq_codes:
if code.t == t and (q is None or code.q == q):
coords.append((s, code))
return coords
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
steps_with_timesteps = self.get_steps_with_timestep(t, q)
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
device: tp.Union[torch.device, str] = 'cpu'):
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
Args:
timesteps (int): Maximum number of timesteps steps to consider.
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
device (Union[torch.device, str]): Device for created tensors.
Returns:
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
"""
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
# use the proper layout based on whether we limit ourselves to valid steps only or not,
# note that using the valid_layout will result in a truncated sequence up to the valid steps
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
# fill indexes with last sequence step value that will correspond to our special token
# the last value is n_q * timesteps, as we have flattened z and appended the special token
# as the last token, which corresponds to the index n_q * timesteps
indexes[:] = n_q * timesteps
# iterate over the pattern and fill scattered indexes and mask
for s, sequence_coords in enumerate(ref_layout):
for coords in sequence_coords:
if coords.t < timesteps:
indexes[coords.q, s] = coords.t + coords.q * timesteps
mask[coords.q, s] = 1
indexes = torch.from_numpy(indexes).to(device)
mask = torch.from_numpy(mask).to(device)
return indexes, mask
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
"""Build sequence corresponding to the pattern from the input tensor z.
The sequence is built using up to sequence_steps if specified, and non-pattern
coordinates are filled with the special token.
Args:
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
Steps that are beyond valid steps will be replaced by the special_token in that case.
Returns:
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
"""
B, K, T = z.shape
indexes, mask = self._build_pattern_sequence_scatter_indexes(
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
)
z = z.view(B, -1)
# we append the special token as the last index of our flattened z tensor
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
values = z[:, indexes.view(-1)]
values = values.view(B, K, indexes.shape[-1])
return values, indexes, mask
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
keep_only_valid_steps: bool = False,
is_model_output: bool = False,
device: tp.Union[torch.device, str] = 'cpu'):
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
from interleaving pattern.
Args:
sequence_steps (int): Sequence steps.
n_q (int): Number of codebooks.
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
Steps that are beyond valid steps will be replaced by the special_token in that case.
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
device (Union[torch.device, str]): Device for created tensors.
Returns:
indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
"""
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
# TODO(jade): Do we want to further truncate to only valid timesteps here as well?
timesteps = self.timesteps
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
assert sequence_steps <= len(ref_layout), \
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
# ensure we take the appropriate indexes to keep the model output from the first special token as well
if is_model_output:
ref_layout = ref_layout[1:]
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
# fill indexes with last sequence step value that will correspond to our special token
indexes[:] = n_q * sequence_steps
for s, sequence_codes in enumerate(ref_layout):
if s < sequence_steps:
for code in sequence_codes:
if code.t < timesteps:
indexes[code.q, code.t] = s + code.q * sequence_steps
mask[code.q, code.t] = 1
indexes = torch.from_numpy(indexes).to(device)
mask = torch.from_numpy(mask).to(device)
return indexes, mask
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
are filled with the special token.
Args:
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
Returns:
values (torch.Tensor): Reverted (de-interleaved) sequence, of shape [B, K, T] with T
corresponding either to the timesteps if provided, or to the total timesteps in the pattern otherwise.
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
"""
B, K, S = s.shape
indexes, mask = self._build_reverted_sequence_scatter_indexes(
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
)
s = s.view(B, -1)
# we append the special token as the last index of our flattened z tensor
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
values = s[:, indexes.view(-1)]
values = values.view(B, K, indexes.shape[-1])
return values, indexes, mask
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
"""Revert model logits obtained on a sequence built from the pattern
back to a tensor matching the original sequence.
This method is similar to ``revert_pattern_sequence`` with the following specificities:
1. It is designed to work with the extra cardinality dimension
2. We return the logits for the first sequence item that matches the special_token and
which matching target in the original sequence is the first item of the sequence,
while we skip the last logits as there is no matching target
"""
B, card, K, S = logits.shape
indexes, mask = self._build_reverted_sequence_scatter_indexes(
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
)
logits = logits.reshape(B, card, -1)
# we append the special token as the last index of our flattened z tensor
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
values = logits[:, :, indexes.view(-1)]
values = values.view(B, card, K, indexes.shape[-1])
return values, indexes, mask
class CodebooksPatternProvider(ABC):
"""Abstraction around providing pattern for interleaving codebooks.
The CodebooksPatternProvider abstraction makes it possible to implement various strategies for
defining the interleaving pattern of sequences composed of multiple codebooks. For a given
number of codebooks `n_q`, the pattern provider can generate a specified pattern
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
can be used to construct a new sequence from the original codes respecting the specified
pattern. The pattern is defined as a list of lists of code coordinates, each coordinate
being a tuple of the original timestep and the codebook used to build the new sequence.
Note that all patterns must start with an empty list that is then used to insert a first
sequence step of special tokens in the newly generated sequence.
Args:
n_q (int): number of codebooks.
cached (bool): if True, patterns for a given length are cached. In general
this should be true for efficiency reasons, to avoid synchronization points.
"""
def __init__(self, n_q: int, cached: bool = True):
assert n_q > 0
self.n_q = n_q
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
@abstractmethod
def get_pattern(self, timesteps: int) -> Pattern:
"""Builds pattern with specific interleaving between codebooks.
Args:
timesteps (int): Total number of timesteps.
"""
raise NotImplementedError()
class DelayedPatternProvider(CodebooksPatternProvider):
"""Provider for delayed pattern across delayed codebooks.
Codebooks are delayed in the sequence and sequence steps will contain codebooks
from different timesteps.
Example:
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
The resulting sequence obtained from the returned pattern is:
[[S, 1, 2, 3, 4],
[S, S, 1, 2, 3],
[S, S, S, 1, 2]]
(with S being a special token)
Args:
n_q (int): Number of codebooks.
delays (Optional[List[int]]): Delay for each of the codebooks.
If delays not defined, each codebook is delayed by 1 compared to the previous one.
flatten_first (int): Flatten the first N timesteps.
empty_initial (int): Prepend with N empty list of coordinates.
"""
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
flatten_first: int = 0, empty_initial: int = 0):
super().__init__(n_q)
if delays is None:
delays = list(range(n_q))
self.delays = delays
self.flatten_first = flatten_first
self.empty_initial = empty_initial
assert len(self.delays) == self.n_q
assert sorted(self.delays) == self.delays
def get_pattern(self, timesteps: int) -> Pattern:
out: PatternLayout = [[]]
max_delay = max(self.delays)
if self.empty_initial:
out += [[] for _ in range(self.empty_initial)]
if self.flatten_first:
for t in range(min(timesteps, self.flatten_first)):
for q in range(self.n_q):
out.append([LayoutCoord(t, q)])
for t in range(self.flatten_first, timesteps + max_delay):
v = []
for q, delay in enumerate(self.delays):
t_for_q = t - delay
if t_for_q >= self.flatten_first:
v.append(LayoutCoord(t_for_q, q))
out.append(v)
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class ParallelPatternProvider(DelayedPatternProvider):
"""Provider for parallel pattern across codebooks.
This pattern provider is a special case of the delayed pattern with actually no delay,
hence delays=repeat(0, n_q).
Args:
n_q (int): Number of codebooks.
"""
def __init__(self, n_q: int):
super().__init__(n_q, [0] * n_q)
class UnrolledPatternProvider(CodebooksPatternProvider):
"""Provider for unrolling codebooks pattern.
This pattern provider allows the codebooks to be flattened, either completely or only partially,
while also supporting a per-codebook delay in the flattened representation, unrolling the
codebooks in the sequence.
Example:
1. Flattening of the codebooks.
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
taking n_q = 3 and timesteps = 4:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
will result into:
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
2. Partial flattening of the codebooks. The ``flattening`` parameter specifies the inner step
for each codebook, defining which codebooks to flatten (or keep in parallel), for example
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
will result into:
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
3. Flattening with delay. The ``delays`` parameter further unrolls the sequence of codebooks
by specifying a delay per codebook. Note that the delays between codebooks flattened to the
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
and delays = [0, 3, 3]:
[[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]]
will result into:
[[S, S, S, 1, S, 2, S, 3, S, 4],
[S, S, S, 1, S, 2, S, 3, S, 4],
[1, 2, 3, S, 4, S, 5, S, 6, S]]
Args:
n_q (int): Number of codebooks.
flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
have n_q extra steps for each timestep.
delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
no delay is added and therefore will default to [0] * ``n_q``.
Note that two codebooks that will be flattened to the same inner step
should have the same delay, otherwise the pattern is considered as invalid.
"""
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
delays: tp.Optional[tp.List[int]] = None):
super().__init__(n_q)
if flattening is None:
flattening = list(range(n_q))
if delays is None:
delays = [0] * n_q
assert len(flattening) == n_q
assert len(delays) == n_q
assert sorted(flattening) == flattening
assert sorted(delays) == delays
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
self.max_delay = max(delays)
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
"""Build a flattened codebooks representation as a dictionary of inner step
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
"""
flattened_codebooks: dict = {}
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
if inner_step not in flattened_codebooks:
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
else:
flat_codebook = flattened_codebooks[inner_step]
assert flat_codebook.delay == delay, (
"Delay and flattening between codebooks is inconsistent: ",
"two codebooks flattened to the same position should have the same delay."
)
flat_codebook.codebooks.append(q)
flattened_codebooks[inner_step] = flat_codebook
return flattened_codebooks
@property
def _num_inner_steps(self):
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
"""
return max(self._flattened_codebooks.keys()) + 1
def num_virtual_steps(self, timesteps: int) -> int:
return timesteps * self._num_inner_steps + 1
def get_pattern(self, timesteps: int) -> Pattern:
"""Builds pattern for delay across codebooks.
Args:
timesteps (int): Total number of timesteps.
"""
# the PatternLayout is built as a tuple of sequence position and list of coordinates
# so that it can be reordered properly given the required delay between codebooks of given timesteps
indexed_out: list = [(-1, [])]
max_timesteps = timesteps + self.max_delay
for t in range(max_timesteps):
# for each timestep, we unroll the flattened codebooks,
# emitting the sequence step with the corresponding delay
for step in range(self._num_inner_steps):
if step in self._flattened_codebooks:
# we have codebooks at this virtual step to emit
step_codebooks = self._flattened_codebooks[step]
t_for_q = t + step_codebooks.delay
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
if t_for_q < max_timesteps and t < max_timesteps:
indexed_out.append((t_for_q, coords))
else:
# there is no codebook in this virtual step so we emit an empty list
indexed_out.append((t, []))
out = [coords for _, coords in sorted(indexed_out)]
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class VALLEPattern(CodebooksPatternProvider):
"""Almost VALL-E style pattern. We futher allow some delays for the
codebooks other than the first one.
Args:
n_q (int): Number of codebooks.
delays (Optional[List[int]]): Delay for each of the codebooks.
If delays not defined, each codebook is delayed by 1 compared to the previous one.
"""
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
super().__init__(n_q)
if delays is None:
delays = [0] * (n_q - 1)
self.delays = delays
assert len(self.delays) == self.n_q - 1
assert sorted(self.delays) == self.delays
def get_pattern(self, timesteps: int) -> Pattern:
out: PatternLayout = [[]]
for t in range(timesteps):
out.append([LayoutCoord(t, 0)])
max_delay = max(self.delays)
for t in range(timesteps + max_delay):
v = []
for q, delay in enumerate(self.delays):
t_for_q = t - delay
if t_for_q >= 0:
v.append(LayoutCoord(t_for_q, q + 1))
out.append(v)
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class MusicLMPattern(CodebooksPatternProvider):
"""Almost MusicLM style pattern. This is equivalent to full flattening
but in a different order.
Args:
n_q (int): Number of codebooks.
group_by (int): Number of codebooks to group together.
"""
def __init__(self, n_q: int, group_by: int = 2):
super().__init__(n_q)
self.group_by = group_by
def get_pattern(self, timesteps: int) -> Pattern:
out: PatternLayout = [[]]
for offset in range(0, self.n_q, self.group_by):
for t in range(timesteps):
for q in range(offset, offset + self.group_by):
out.append([LayoutCoord(t, q)])
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
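To make the delayed layout concrete, here is a dependency-free sketch (a hypothetical helper, not part of the module above, and not using torch or the Pattern class) that reproduces the DelayedPatternProvider docstring example: each codebook row is shifted right by its delay behind one leading special-token step, and only the fully valid steps are kept, as ``valid_layout`` would:

```python
def delayed_valid_rows(codes, delays=None, special="S"):
    """Shift codebook rows by per-codebook delays, keeping only fully
    valid sequence steps (mirroring Pattern.valid_layout).

    codes: list of n_q rows, each a list of T codes.
    """
    n_q, T = len(codes), len(codes[0])
    if delays is None:
        delays = list(range(n_q))  # default: codebook q delayed by q
    rows = []
    for q, row in enumerate(codes):
        # one initial special step, then the codebook's own delay
        shifted = [special] * (1 + delays[q]) + list(row)
        rows.append(shifted[: T + 1])  # truncate to the valid region
    return rows

codes = [[1, 2, 3, 4]] * 3  # n_q = 3 identical codebook streams
rows = delayed_valid_rows(codes)
# rows == [['S', 1, 2, 3, 4],
#          ['S', 'S', 1, 2, 3],
#          ['S', 'S', 'S', 1, 2]]
```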

# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
import logging
import warnings
from typing import Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch.types import _dtype as DType
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
def _canonical_mask(
mask: Optional[Tensor],
mask_name: str,
other_type: Optional[DType],
other_name: str,
target_type: DType,
check_other: bool = True,
) -> Optional[Tensor]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = torch.is_floating_point(mask)
if _mask_dtype != torch.bool and not _mask_is_float:
raise AssertionError(
f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
f"Support for mismatched {mask_name} and {other_name} "
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
mask = (
torch.zeros_like(mask, dtype=target_type)
.masked_fill_(mask, float("-inf"))
)
return mask
def _in_projection_packed(
q: Tensor,
k: Tensor,
v: Tensor,
w: Tensor,
b: Optional[Tensor] = None,
) -> List[Tensor]:
r"""
Performs the in-projection step of the attention operation, using packed weights.
Output is a triple containing projection tensors for query, key and value.
Args:
q, k, v: query, key and value tensors to be projected. For self-attention,
these are typically the same tensor; for encoder-decoder attention,
k and v are typically the same tensor. (We take advantage of these
identities for performance if they are present.) Regardless, q, k and v
must share a common embedding dimension; otherwise their shapes may vary.
w: projection weights for q, k and v, packed into a single tensor. Weights
are packed along dimension 0, in q, k, v order.
b: optional projection biases for q, k and v, packed into a single tensor
in q, k, v order.
Shape:
Inputs:
- q: :math:`(..., E)` where E is the embedding dimension
- k: :math:`(..., E)` where E is the embedding dimension
- v: :math:`(..., E)` where E is the embedding dimension
- w: :math:`(E * 3, E)` where E is the embedding dimension
- b: :math:`E * 3` where E is the embedding dimension
Output:
- in output list :math:`[q', k', v']`, each output tensor will have the
same shape as the corresponding input tensor.
"""
E = q.size(-1)
if k is v:
if q is k:
# self-attention
proj = F.linear(q, w, b)
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention
w_q, w_kv = w.split([E, E * 2])
if b is None:
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
q_proj = F.linear(q, w_q, b_q)
kv_proj = F.linear(k, w_kv, b_kv)
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
if input is None:
return None
elif isinstance(input, torch.Tensor):
return input.dtype
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``forward()`` will use a special optimized implementation if all of the following
conditions are met:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- dropout is 0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples::
>>> # xdoctest: +SKIP
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (
self.kdim == embed_dim and self.vdim == embed_dim
)
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs)
)
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs)
)
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
# go down this route with voicecraft
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias: # True by default
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self._reset_parameters()
else:
if not self._qkv_same_embed_dim:
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = self.in_proj_linear.bias
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
self.add_zero_attn = add_zero_attn
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.0)
constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
past: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (
self.in_proj_bias is not None
and query.dtype != self.in_proj_bias.dtype
):
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None
and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU"
)
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None
else attn_mask,
need_weights,
average_attn_weights,
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
)
else:
# re-implement the attention here to expose the k/v cache
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
num_heads = self.num_heads
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
head_dim = self.embed_dim // self.num_heads
assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
# k_present, v_present = k, v
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
src_len = k.size(1)
if past is not None and past.ndim > 2:
expected_src_len = src_len + past[0].shape[-2]
else:
expected_src_len = src_len
# ensure attn_mask's dim is 3 (attn_mask may be None after _canonical_mask)
if attn_mask is not None:
    if attn_mask.dim() == 2:
        correct_2d_size = (tgt_len, expected_src_len)
        if attn_mask.shape != correct_2d_size:
            raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
        attn_mask = attn_mask.unsqueeze(0)
    elif attn_mask.dim() == 3:
        correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
        if attn_mask.shape != correct_3d_size:
            raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
    else:
        raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, expected_src_len), \
f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
if attn_mask is None:
attn_mask = key_padding_mask
else:
attn_mask = attn_mask + key_padding_mask
if not self.training:
dropout_p = 0.0
else:
dropout_p = self.dropout
if need_weights:
raise NotImplementedError("need_weights not implemented for voicecraft")
# B, Nt, E = q.shape
# q_scaled = q / math.sqrt(E)
# assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
# if attn_mask is not None:
# attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
# else:
# attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
# attn_output_weights = softmax(attn_output_weights, dim=-1)
# if dropout_p > 0.0:
# attn_output_weights = dropout(attn_output_weights, p=dropout_p)
# attn_output = torch.bmm(attn_output_weights, v)
# attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
# attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
# attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# # optionally average attention weights over heads
# attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
# if average_attn_weights:
# attn_output_weights = attn_output_weights.mean(dim=1)
# if not is_batched:
# # squeeze the output if input was unbatched
# attn_output = attn_output.squeeze(1)
# attn_output_weights = attn_output_weights.squeeze(0)
# return attn_output, attn_output_weights
else:
# attn_mask can be either (L,S) or (N*num_heads, L, S)
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
# in order to match the input for SDPA of (N, num_heads, L, S)
if attn_mask is not None:
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
# logging.info(f"shape of past: {past.shape}")
if past is not None:
present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
if past.ndim > 2: # a cache with ndim > 2 means kv caching is in use; otherwise past is just a placeholder and the cache is not actually used
pk, pv = past
k = torch.cat([pk, k], dim=-2)
v = torch.cat([pv, v], dim=-2)
else:
present = None
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
# if self.training:
# return attn_output, None
# else:
# return (attn_output, present), None
# hard-coded; this code does not yet support returning attention weights
attn_output_weights = None
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), present
else:
return attn_output, present
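The `past`/`present` handling above (concatenating cached keys and values along the sequence dimension) works because attending from only the newest query over the concatenated cache reproduces the corresponding row of full causal attention. A standalone check with toy shapes and plain `scaled_dot_product_attention` (variable names here are illustrative, not from this file):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, head_dim, seq_len = 1, 2, 4, 3
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# full causal attention over the whole sequence
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
full = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

# incremental step: "past" caches k/v of the first two positions; the new
# step concatenates along the sequence dim, exactly as forward() does
pk, pv = k[..., :2, :], v[..., :2, :]
k_cat = torch.cat([pk, k[..., 2:, :]], dim=-2)
v_cat = torch.cat([pv, v[..., 2:, :]], dim=-2)
step = F.scaled_dot_product_attention(q[..., 2:, :], k_cat, v_cat)

# the cached step reproduces the last row of full causal attention
assert torch.allclose(full[..., 2:, :], step, atol=1e-5)
```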

# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
class TokenEmbedding(nn.Module):
def __init__(
self,
dim_model: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.dim_model = dim_model
self.dropout = torch.nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
@property
def weight(self) -> torch.Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
X = self.word_embeddings(x)
X = self.dropout(X)
return X
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
dim_model: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.dim_model = dim_model
self.x_scale = math.sqrt(dim_model) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = torch.nn.Dropout(p=dropout)
self.reverse = False
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.dim_model)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32
).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.dim_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.dim_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)
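The table `extend_pe` builds is the standard sinusoidal positional encoding; reproduced in isolation at toy sizes, using the same formulas as above:

```python
import math
import torch

# rebuild the table extend_pe() constructs, for dim_model=8, length=4
dim_model, length = 8, 4
position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, dim_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / dim_model)
)
pe = torch.zeros(length, dim_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# position 0 encodes as zeros in the sine slots and ones in the cosine slots
assert torch.all(pe[0, 0::2] == 0.0) and torch.all(pe[0, 1::2] == 1.0)
```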

import torch
import torch.nn.functional as F
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(
max(top_k, min_tokens_to_keep), logits.size(-1)
) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1
)
# Remove tokens with cumulative probability above the threshold (tokens at or below it are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
# temperature: (`optional`) float
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
# top_k: (`optional`) int
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
# top_p: (`optional`) float
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
logits = logits / temperature
# Top-p/top-k filtering
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
# Sample
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
return token
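The core top-k masking step used by `top_k_top_p_filtering` can be isolated as follows (the logits values are made up for illustration):

```python
import torch

# toy logits for a batch of 1 over a vocab of 5
logits = torch.tensor([[2.0, 1.0, 0.5, 3.0, -1.0]])
top_k = 2
# everything strictly below the k-th largest logit is masked out
kth = torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(logits < kth, -float("inf"))
probs = torch.softmax(filtered, dim=-1)
# only the two kept tokens (indices 0 and 3) carry probability mass
assert probs[0, 1] == 0.0 and probs[0, 2] == 0.0 and probs[0, 4] == 0.0
assert probs[0, 3] > probs[0, 0] > 0.0
```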

models/modules/scaling.py: new file, 1406 lines (diff suppressed because it is too large)
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
import copy
import numbers
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from .activation import MultiheadAttention
from .scaling import ActivationBalancer, BalancedDoubleSwish
from .scaling import BasicNorm as _BasicNorm
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class AdaptiveLayerNorm(nn.Module):
r"""Adaptive Layer Normalization"""
def __init__(self, d_model, norm) -> None:
super(AdaptiveLayerNorm, self).__init__()
self.project_layer = nn.Linear(d_model, 2 * d_model)
self.norm = norm
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return weight * self.norm(input) + bias
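The conditioning mechanism in `AdaptiveLayerNorm` above is a projection of the embedding into a per-feature scale and shift; the split in isolation (a plain `nn.LayerNorm` stands in for the wrapped `norm`, sizes are toy):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 4
norm = nn.LayerNorm(d_model)
project_layer = nn.Linear(d_model, 2 * d_model)

x = torch.randn(2, 3, d_model)          # hidden states
embedding = torch.randn(2, 3, d_model)  # conditioning ("stage") embedding

# split the projected embedding into a per-feature scale and shift
weight, bias = torch.split(project_layer(embedding), d_model, dim=-1)
out = weight * norm(x) + bias
assert out.shape == x.shape
```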
class BasicNorm(_BasicNorm):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
):
super(BasicNorm, self).__init__(d_model, eps=eps)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
super(BasicNorm, self).forward(input),
embedding,
)
assert embedding is None
return super(BasicNorm, self).forward(input)
class BalancedBasicNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
):
super(BalancedBasicNorm, self).__init__()
self.balancer = ActivationBalancer(
d_model,
channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
max_abs=6.0,
)
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return self.norm((self.balancer(input), embedding))
assert embedding is None
return self.norm(self.balancer(input))
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
assert embedding is None
return input
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
activation = activation(d_model)
elif activation == BalancedDoubleSwish:
activation = BalancedDoubleSwish(d_model)
# # We can't test self.activation in forward() in TorchScript,
# # so stash some information about it instead.
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
# self.activation_relu_or_gelu = 1
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
# self.activation_relu_or_gelu = 2
# else:
# self.activation_relu_or_gelu = 0
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
else:
self.norm1 = norm1
self.norm2 = norm2
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
need_weights: Optional[bool] = False,
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
x, stage_embedding = src, None
is_src_tuple = False
if isinstance(src, tuple):
x, stage_embedding = src
is_src_tuple = True
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if need_weights:
if self.norm_first:
out, attn = self._sa_block_attn(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,
past
)
out, present = out # present is the kvcache of the present timestep
x = x + out
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past)
out, present = out # present is the kvcache of the present timestep
x = self.norm1(
x + out,
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
assert not is_src_tuple
# return (x, stage_embedding)
return (x, attn)
else:
if self.norm_first:
out = self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask, past
)
out, present = out # present is the kvcache of the present timestep
x = x + out
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
out = self._sa_block(x, src_mask, src_key_padding_mask, past)
out, present = out # present is the kvcache of the present timestep
x = self.norm1(
x + out,
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
x = (x, stage_embedding)
if present is not None:
x = [x, present]
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
past: Optional[Tensor] = None,
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
past=past
)
x, present = x
return self.dropout1(x), present
# self-attention block, also return attention weights
def _sa_block_attn(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
past: Optional[Tensor] = None,
) -> Tensor:
x, attn = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=True,
past=past
)
x, present = x
return (self.dropout1(x), present), attn
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
need_weights: Optional[bool] = False,
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
return_layer_states: return layers' state (optional).
Shape:
see the docs in Transformer class.
"""
if return_layer_states:
assert not need_weights
layer_states = [] # layers' output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
past=past
)
layer_states.append(output[0])
if self.norm is not None:
output = self.norm(output)
return layer_states, output
if need_weights:
assert not return_layer_states
layer_attn = [] # layers' attention weights
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
need_weights=True,
past=past
)
layer_attn.append(output[1])
if self.norm is not None:
output = self.norm(output)
return layer_attn, output
output = src
all_present = []
for n_layer, mod in enumerate(self.layers):
output = mod(
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
)
if isinstance(output, list):
output, present = output
all_present.append(present)
if self.norm is not None:
output = self.norm(output)
if all_present:
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
output = [output, all_present]
return output
class TransformerDecoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.multihead_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
self.activation = activation(d_model)
elif activation == BalancedDoubleSwish:
self.activation = BalancedDoubleSwish(d_model)
else:
self.activation = activation
if adaptive_layer_norm:
norm1 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
norm3 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
else:
self.norm1 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
self.norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
if layer_norm_cls == IdentityNorm:
self.norm3 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
self.norm3 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
tgt_is_tuple = False
if isinstance(tgt, tuple):
x, stage_embedding = tgt
tgt_is_tuple = True
else:
x, stage_embedding = tgt, None
if self.norm_first:
x = x + self._sa_block(
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
)
x = x + self._mha_block(
self.norm2(x, stage_embedding),
memory,
memory_mask,
memory_key_padding_mask,
)
x = x + self._ff_block(self.norm3(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
stage_embedding,
)
x = self.norm2(
x
+ self._mha_block(
x, memory, memory_mask, memory_key_padding_mask
),
stage_embedding,
)
x = self.norm3(x + self._ff_block(x), stage_embedding)
if tgt_is_tuple:
return (x, stage_embedding)
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# multihead attention block
def _mha_block(
self,
x: Tensor,
mem: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
) -> Tensor:
x = self.multihead_attn(
x,
mem,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)

models/modules/utils.py
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng
import torch
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
return expanded_lengths >= lengths.unsqueeze(-1)
def generate_partial_autoregressive_mask(sz, start, end):
mask = torch.zeros(sz, sz).bool()
mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1)
mask[:start, start:end] = True
mask[end:, start:end] = True
return mask
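Both helpers above can be exercised directly. A small self-contained sketch (the helpers are re-declared so the snippet runs on its own; the expected `make_pad_mask` output is the one from its docstring):

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    return expanded_lengths >= lengths.unsqueeze(-1)

def generate_partial_autoregressive_mask(sz, start, end):
    mask = torch.zeros(sz, sz).bool()
    mask[start:end, start:end] = torch.triu(
        torch.ones(end - start, end - start, dtype=torch.bool), diagonal=1)
    mask[:start, start:end] = True
    mask[end:, start:end] = True
    return mask

# positions beyond each sentence length are True (i.e. padded)
pad_mask = make_pad_mask(torch.tensor([1, 3, 2, 5]))
print(pad_mask)

# causal (upper-triangular) masking only inside [start, end); the span is
# also hidden from every position outside it
ar_mask = generate_partial_autoregressive_mask(5, start=1, end=3)
print(ar_mask)
```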

models/voicecraft.py (1402 lines; diff suppressed because it is too large)

steps/__init__.py (empty)

steps/optim.py (1123 lines; diff suppressed because it is too large)

steps/trainer.py
import time
import os, random
import torch
import math, pickle
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import torch.nn as nn
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.utils.data.distributed import DistributedSampler
import logging
from data import gigaspeech
from models import voicecraft
from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, AverageMeter, print_model_info
from .optim import ScaledAdam, Eden
class Trainer:
def __init__(self, args, world_size, rank):
self.start_time = time.time()
self.args = args
self.world_size, self.rank = world_size, rank
self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
if self.rank == 0:
self.writer = SummaryWriter(args.exp_dir)
self.seed_everything(seed=self.args.seed)
self.meters = self._setup_meters()
self.progress, self.total_progress = self._setup_progress()
self.model, self.trainables, self.optim_states, self.scheduler_states = self._setup_models()
self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader()
if self.args.num_steps is not None:
self.total_step = self.args.num_steps
self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
else:
self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs
self.optimizer, self.scheduler = self._setup_optimizer()
self.scaler = torch.cuda.amp.GradScaler()
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank], find_unused_parameters=False)
if self.rank == 0:
self.early_stop_accu_steps = 0
if self.args.dynamic_batching:
logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in an inference batch: {self.args.val_max_num_tokens}")
else:
logging.info(f"batch size (summed over all GPUs): {self.args.batch_size}")
def train(self):
flag = True
skip_flag = False
data_start_time = time.time()
while flag:
self.train_sampler.set_epoch(self.progress['epoch'])
for i, batch in enumerate(self.train_loader):
data_end_time = time.time()
self.model.train()
if self.progress['step'] > self.total_step:
flag = False
self.validate_and_save()
if self.rank == 0:
self.writer.close()
break
if isinstance(self.scheduler, Eden):
self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1)
if self.args.optimizer_name == "ScaledAdam":
cur_lr = self.scheduler.get_last_lr()[0]
else:
lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
assert lrs[0] == lrs[1]
cur_lr = lrs[0]
if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
if hasattr(self, "wandb"): self.wandb.log({"train/lr": cur_lr}, step=self.progress['step']) # wandb is optional and never set up in this class; guard avoids AttributeError
all_inds = list(range(len(batch['y'])))
sum_losses = 0
sum_top10acc = 0
sum_ntoken = 0
sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
for j in range(self.args.gradient_accumulation_steps):
cur_ind = all_inds[j::self.args.gradient_accumulation_steps]
cur_batch = {key: batch[key][cur_ind] for key in batch}
with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32):
out = self.model(cur_batch)
record_loss = out['loss'].detach().to(self.rank)
top10acc = out['top10acc'].to(self.rank)
effective_ntoken = out['effective_ntoken'].to(self.rank)
is_nan = torch.tensor(int(torch.isnan(record_loss).any()), dtype=torch.float32, device=self.rank)
dist.all_reduce(record_loss, op=dist.ReduceOp.SUM)
dist.all_reduce(top10acc, op=dist.ReduceOp.SUM)
dist.all_reduce(effective_ntoken, op=dist.ReduceOp.SUM)
dist.all_reduce(is_nan, op=dist.ReduceOp.SUM)
# check if loss is nan
if is_nan.item() > 0:
logging.info(f"loss at step {self.progress['step']} is nan, therefore skip this batch")
skip_flag = True
continue
sum_losses += record_loss.item()
sum_top10acc += top10acc.item()
sum_ntoken += effective_ntoken.item()
if 'top10acc_by_codebook' in out:
for cb in range(self.args.n_codebooks):
top10acc_cbi = out['top10acc_by_codebook'][cb]
dist.all_reduce(top10acc_cbi, op=dist.ReduceOp.SUM)
sum_top10acc_cbi[cb] += top10acc_cbi.item()
if self.rank == 0:
average_loss = sum_losses / sum_ntoken
average_top10acc = sum_top10acc / sum_ntoken
self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size)
self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size)
average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
for cb in range(self.args.n_codebooks):
self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size)
if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
for cb in range(self.args.n_codebooks):
self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])
if self.args.optimizer_name == "ScaledAdam":
self.scaler.scale(out['loss']).backward()
else:
self.scaler.scale(out['loss']/out['effective_ntoken']).backward()
if skip_flag:
self.optimizer.zero_grad()
skip_flag = False
continue
if self.args.optimizer_name != "ScaledAdam":
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.args.optimizer_name == "ScaledAdam":
self.scheduler.step_batch(self.progress['step'])
else:
self.scheduler.step()
if self.rank == 0:
self.meters['data_time'].update(data_end_time - data_start_time)
self.meters['train_time'].update(time.time() - data_end_time)
if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])
# logging
if self.progress['step'] % self.args.print_every_n_steps == 0:
log_out = {}
log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
log_out['lr'] = f"{cur_lr:.7f}"
log_out['ntokens'] = f"{sum_ntoken}"
for key in self.meters:
if self.meters[key].val != 0 or self.meters[key].avg != 0:
log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
logging.info(log_out)
if np.isnan(self.meters['train_loss'].avg):
logging.warning("training diverged...")
raise RuntimeError("training diverged...")
# validation and save models
if self.progress['step'] % self.args.val_every_n_steps == 0:
dist.barrier()
self.validate_and_save()
self.progress['step'] += 1
self.progress['cur_step'] += 1
data_start_time = time.time()
self.progress['epoch'] += 1
self.progress['cur_step'] = 0 # reset cur_step to be 0
dist.destroy_process_group()
def validate_and_save(self):
self.model.eval()
score = self.validate(self.valid_loader)
if self.rank == 0:
if self.args.early_stop_threshold > 0:
if self.progress['best_score'] - score < self.args.early_stop_threshold:
self.early_stop_accu_steps += self.args.val_every_n_steps
if self.early_stop_accu_steps >= self.args.early_stop_step-1:
logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
dist.destroy_process_group()
raise RuntimeError("early stop")
else:
self.early_stop_accu_steps = 0
if (score < self.progress['best_score']):
self.progress['best_step'] = self.progress['step']
self.progress['best_score'] = score
save_path = os.path.join(self.args.exp_dir,"best_bundle.pth")
torch.save(
{
"model": self.model.module.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"config": self.args,
"phn2num": self.train_loader.dataset.phn2num
},save_path
)
logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")
self._save_progress()
save_path = os.path.join(self.args.exp_dir,"bundle.pth")
torch.save(
{
"model": self.model.module.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"config": self.args,
"phn2num": self.train_loader.dataset.phn2num
},save_path
)
logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")
dist.barrier()
def validate(self, valid_loader=None, hide_progress=True):
if valid_loader is None:
valid_loader = self.valid_loader
self.model.eval()
start_val_time = time.time()
sum_losses = 0
sum_top10acc = 0
sum_ntoken = 0
sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
with torch.no_grad():
for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
out = self.model(batch)
sum_losses += out['loss']
sum_top10acc += out['top10acc']
sum_ntoken += out['effective_ntoken']
if 'top10acc_by_codebook' in out:
for cb in range(self.args.n_codebooks):
sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]
dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
if 'top10acc_by_codebook' in out:
for cb in range(self.args.n_codebooks):
dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)
if self.rank == 0:
val_loss = sum_losses / sum_ntoken
val_top10acc = sum_top10acc / sum_ntoken
# logging
self.meters['val_loss'].update(val_loss)
logging.info(f"val loss: {val_loss:.5f}")
self.writer.add_scalar("val/loss", val_loss, self.progress['step'])
self.meters['val_top10acc'].update(val_top10acc)
logging.info(f"val top10acc: {val_top10acc:.5f}")
self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])
for cb in range(self.args.n_codebooks):
average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])
logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")
return val_loss.item()
else:
return None
def _setup_meters(self):
meters = {}
meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
for name in meter_names:
meters[name] = AverageMeter()
return meters
def _setup_progress(self):
progress = {}
progress['best_step'] = 1
progress['best_score'] = np.inf # this records loss value
progress['step'] = 1
progress['epoch'] = 1
progress['cur_step'] = 0 # step in the current epoch, for resuming the sampler
total_progress = []
# if self.args.resume or self.args.validate:
if self.args.resume:
progress_pkl = "%s/progress.pkl" % self.args.exp_dir
with open(progress_pkl, "rb") as f:
total_progress = pickle.load(f)
progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
if self.rank == 0:
logging.info("\nResume training from:")
logging.info(" epoch = %s" % progress['epoch'])
logging.info(" cur_step = %s" % progress['cur_step'])
logging.info(" step = %s" % progress['step'])
logging.info(" best_step = %s" % progress['best_step'])
logging.info(" best_score = %s" % progress['best_score'])
return progress, total_progress
def _save_progress(self):
self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time])
with open("%s/progress.pkl" % self.args.exp_dir, "wb") as f:
pickle.dump(self.total_progress, f)
def _setup_dataloader(self):
assert self.args.dataset == 'gigaspeech', "only gigaspeech is supported for now"
train_dataset, val_dataset = gigaspeech.dataset(self.args, 'train'), gigaspeech.dataset(self.args, 'validation')
if self.args.dynamic_batching:
train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
else:
train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)
if self.progress['step'] > 1:
train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])
if self.args.dynamic_batching:
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_sampler=train_sampler,
num_workers=self.args.num_workers//self.world_size,
collate_fn=train_dataset.collate, persistent_workers=True
)
valid_loader = torch.utils.data.DataLoader(val_dataset,
batch_sampler=valid_sampler,
num_workers=self.args.num_workers//self.world_size,
collate_fn=val_dataset.collate, persistent_workers=True
)
else:
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=self.args.batch_size//self.world_size, sampler=train_sampler, num_workers=self.args.num_workers//self.world_size,
collate_fn=train_dataset.collate, persistent_workers=True
)
valid_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=self.args.batch_size//self.world_size, sampler=valid_sampler,
num_workers=self.args.num_workers//self.world_size,
collate_fn=val_dataset.collate, persistent_workers=True
)
return len(train_dataset), train_sampler, train_loader, valid_loader
def _setup_models(self):
model = voicecraft.VoiceCraft(self.args)
if self.rank == 0:
logging.info(model)
logging.info("model parameters")
print_model_info(model)
if self.progress['step'] > 1:
bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
model.load_state_dict(bundle['model'])
optim_states = bundle['optimizer']
scheduler_states = bundle['scheduler']
if self.rank == 0:
logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
del bundle['model']
else:
optim_states = None
scheduler_states = None
if self.args.load_model_from is not None and self.progress['step'] <= 1:
sd = torch.load(self.args.load_model_from, map_location="cpu")['model']
model.load_state_dict(sd)
del sd
if self.args.optimizer_name == "ScaledAdam":
trainables = [p for p in model.parameters() if p.requires_grad]
else:
no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
"weight_decay": 0.0,
},
]
if len(optimizer_grouped_parameters[1]['params']) == 0:
logging.warning("no parameters matched the no-weight-decay patterns (bias, embeddings, layernorm); check model parameter names")
trainables = [optimizer_grouped_parameters[0]] # keep it a list of param groups, as AdamW expects
else:
trainables = optimizer_grouped_parameters
model.to(self.device)
return model, trainables, optim_states, scheduler_states
def _setup_optimizer(self):
if self.args.optimizer_name == "ScaledAdam":
parameters_names = []
parameters_names.append([n for n,p in self.model.named_parameters() if p.requires_grad])
optimizer = ScaledAdam(
self.trainables,
lr=self.args.lr,
betas=(0.9, 0.95),
clipping_scale=2.0,
parameters_names=parameters_names,
show_dominant_parameters=False,
clipping_update_period=self.args.clipping_update_period,
)
scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction)
else:
optimizer = AdamW(self.trainables, lr=self.args.lr)
warmup_steps = self.total_step * self.args.warmup_fraction
def lr_lambda(current_step: int):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return max(
0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
)
scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
# if resume
if self.progress['step'] > 1:
optimizer.load_state_dict(self.optim_states)
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda()
del self.optim_states
scheduler.load_state_dict(self.scheduler_states)
optimizer.zero_grad()
return optimizer, scheduler
def seed_everything(self, seed=1):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
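The AdamW branch of `_setup_optimizer` uses a linear warmup to the base learning rate followed by a linear decay to zero. A standalone sketch of that schedule (the step counts here are made up for illustration):

```python
def make_lr_lambda(total_step: int, warmup_fraction: float):
    """Linear warmup to the base LR, then linear decay to 0, as in _setup_optimizer."""
    warmup_steps = total_step * warmup_fraction

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(
            0.0,
            float(total_step - current_step) / float(max(1, total_step - warmup_steps)),
        )

    return lr_lambda

scale = make_lr_lambda(total_step=100, warmup_fraction=0.1)
print(scale(0), scale(10), scale(55), scale(100))  # 0.0 1.0 0.5 0.0
```

Passing `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` multiplies the optimizer's base LR by these factors at each `scheduler.step()`.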

steps/trainer_utils.py
import torch
import math
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
import copy
import numpy as np
from typing import List
from scipy.stats import lognorm
import logging
class StatefulDistributedSampler(Sampler[int]):
def __init__(self, dataset, batch_size, num_replicas = None, rank = None, shuffle = True, seed = 0, drop_last = False):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.cur_epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
self.continue_flag = False
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
if self.continue_flag:
self.indices = self.indices[int(self.cur_step*self.batch_size):]
self.num_samples = len(self.indices)
self.continue_flag = False
def __iter__(self):
for idx in self.indices:
yield idx
def set_epoch_resume(self, epoch, cur_step):
self.epoch = epoch
self.cur_step = cur_step
self.continue_flag = True
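`set_epoch` above pads the index list so it divides evenly across replicas, then strides through it by rank (`indices[rank:total_size:num_replicas]`). A pure-Python sketch of that partitioning, with made-up sizes (shuffling and resume state omitted):

```python
import math

def partition_indices(n_examples, num_replicas, rank, drop_last=False):
    """Mirror of StatefulDistributedSampler's pad-then-stride split (no shuffling)."""
    indices = list(range(n_examples))
    num_samples = (
        math.ceil((n_examples - num_replicas) / num_replicas)
        if drop_last and n_examples % num_replicas
        else math.ceil(n_examples / num_replicas)
    )
    total_size = num_samples * num_replicas
    if not drop_last:
        # repeat indices from the front until the list is evenly divisible
        padding = total_size - len(indices)
        indices += (indices * math.ceil(max(padding, 1) / len(indices)))[:padding]
    else:
        # drop the tail instead
        indices = indices[:total_size]
    return indices[rank:total_size:num_replicas]

# 10 examples over 3 replicas: the list is padded to 12, and ranks interleave
print([partition_indices(10, 3, r) for r in range(3)])
```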
class StatefulSampler(Sampler):
def __init__(self, data_source_length, batch_size, use_random=True, seed=1, epoch=0):
self.use_random = use_random
self.data_source_length = data_source_length
self.num_samples = self.data_source_length
self.batch_size = batch_size
self.continue_flag = False
self.seed = seed
self.epoch = epoch
self.cur_step = 0
def __len__(self):
return self.num_samples
def __iter__(self):
for idx in self.indices:
yield idx
def set_epoch(self, epoch):
self.epoch = epoch
if self.use_random:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
self.indices = torch.randperm(self.data_source_length, generator=g).tolist() # type: ignore[arg-type]
else:
self.indices = list(range(self.data_source_length)) # type: ignore[arg-type]
if self.continue_flag:
self.continue_flag = False
self.indices = self.indices[int(self.cur_step*self.batch_size):]
self.num_samples = len(self.indices)
def set_epoch_resume(self, epoch, cur_step):
self.epoch = epoch
self.cur_step = cur_step
self.continue_flag = True
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
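`AverageMeter` keeps a weighted running average, where `n` is the weight (here, the number of items a value summarizes). A quick self-contained check:

```python
class AverageMeter:
    """Computes and stores the average and current value (as in trainer_utils)."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
meter.update(2.0, n=3)   # e.g. a batch of 3 items with mean loss 2.0
meter.update(4.0, n=1)   # a batch of 1 item with loss 4.0
print(meter.val, meter.avg)  # 4.0 2.5
```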
def print_model_info(model, print_model = False, print_params = True):
if print_model:
logging.info(model)
if print_params:
all_params = {}
for name, p in model.named_parameters():
name = name.split(".")[0]
if name in all_params:
all_params[name] += p.numel()
else:
all_params[name] = p.numel()
logging.info("num of parameters of each components:")
for name in all_params:
logging.info(f"{name}: {all_params[name]/1000000.:.2f}m")
class DistributedDynamicBatchSampler(Sampler):
"""
modified from SpeechBrain, https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/sampler.py#L307
This BatchSampler batches examples together by grouping them by their length.
Every example in a batch has approximately the same length, so
padding is minimized.
This enables faster training on datasets
where length of examples can vary significantly (e.g Librispeech).
Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length
Dynamic batching is performed by specifying a max_batch_length which is the
upper limit for the sum of the length of examples in a batch:
e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6
ex1 and ex2 will be placed, alone, in two distinct batches.
Length for each example can be obtained in two manners.
If the input dataset is a DynamicItemDataset it can be obtained by specifying a
length_func. Default assumes a "duration" entry is in the annotation.
Length for each example can also be passed to this class upon instantiation
by specifying a list containing the length for each example and passing it to
lengths_list.
Examples are grouped together by defining a set of possible discrete intervals
(buckets). Examples whose length fall into these intervals can be batched together.
The number of buckets can be specified by using the arg num_buckets.
There is usually an optimal range for the value of this argument.
If num_buckets == 1, all examples can be batched together. This gives maximum randomization,
but training will be slower because a large fraction of the values will be padding,
since long and short examples can land in the same batch.
As the number of buckets grows, only examples with similar
length can be grouped together.
This trades off speed against randomization.
TLDR: low number -> better randomization, high number -> faster training.
Note that if set too high, training speed will decrease: as num_buckets approaches the number of examples in the dataset, the batch size
becomes small, hurting training speed and possibly performance.
The buckets can also be specified by passing a list to the bucket_boundaries
argument instead of specifying a left_bucket_length and a bucket_length_multiplier.
Example
-------
>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.dataio.sampler import DynamicBatchSampler
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> from speechbrain.dataio.batch import PaddedBatch
>>> import numpy as np
>>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)])
>>> dataset = {"ex_{}".format(x) : {"wav" :torch.randn(x)} for x in item_lengths}
>>> dataset = DynamicItemDataset(dataset)
>>> dataset.set_output_keys(["wav"])
>>> length_func = lambda x : len(x) # trivial in this example
>>> bsampler = DynamicBatchSampler(dataset, 20, 4, length_func, shuffle=False, batch_ordering='descending')
>>> dataloader = SaveableDataLoader(dataset, batch_sampler=bsampler, collate_fn=PaddedBatch)
>>> for i, b in enumerate(dataloader):
... data, length = b["wav"]
>>> assert data.shape[-1] == max(item_lengths)
Arguments
---------
dataset : torch.utils.data.Dataset
Pytorch Dataset from which elements will be sampled.
max_batch_length : int
Upper limit for the sum of the length of examples in a batch.
Should be chosen based on your GPU memory.
num_buckets : int
Number of discrete buckets used to group examples together.
If num_buckets == 1, all examples can be batched together. As the number of buckets
grows, only examples of similar length can be grouped together; this trades off
randomization for speed. Low number -> better randomization, high number -> faster training.
However, if set too high, training speed decreases again: as num_buckets approaches
the number of examples in the dataset, batches become small, hurting training speed
and possibly performance.
NOTE: you have to specify either bucket_boundaries manually or the number of buckets.
length_func : callable
Function used to get length of each example from the dataset.
This argument can be used only when the dataset is a Speechbrain DynamicItemDataset object.
Can be anything: e.g. lambda x: x["duration"]*16000 returns the number of samples
if the duration key in the annotation is in seconds and the file has a 16kHz sampling rate.
shuffle : bool
Whether or not to shuffle examples between epochs.
batch_ordering : string
If ``random``, batches are randomly permuted; otherwise ``ascending`` or ``descending`` sorted by length.
max_batch_ex: int
If set, it limits the maximum number of examples per batch, superseding max_batch_length
whenever the number of examples would exceed this value.
E.g. if you have many short examples and the resulting batch size for those would be
too large, you can use this argument to cap it.
bucket_boundaries : list
Overrides bucket_length_multiplier and left_bucket_length by specifying manually
the buckets right boundaries.
lengths_list: list
Overrides length_func by passing a list containing the length of each example
in the dataset. This argument must be set when the dataset is a plain
Pytorch Dataset object and not a DynamicItemDataset object as length_func
cannot be used on Pytorch Datasets.
epoch : int
The epoch to start at.
drop_last : bool
If ``True``, the sampler will drop the last examples which
have not been grouped.
verbose: bool
If ``True``, log also the stats for each batch at the first epoch.
"""
def __init__(
self,
dataset,
args,
num_replicas = None,
rank = None,
shuffle = True,
seed = 0,
drop_last = False,
length_func=lambda x: x["duration"],
batch_ordering: str = "random",
max_batch_ex: int = None,
bucket_boundaries: List[int] = [],
lengths_list: List[int] = None,
epoch: int = 0,
verbose: bool = False,
):
self.args = args
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
self.num_replicas = num_replicas
self.rank = rank
max_batch_length = self.args.max_num_tokens if dataset.split == "train" else self.args.val_max_num_tokens
logging.info(f"max_num_tokens per GPU for {dataset.split} split: {max_batch_length}")
num_buckets = self.args.num_buckets
#############
self._dataset = dataset
self._ex_lengths = {}
# ex_ids = self._dataset.data_ids
self.verbose = verbose
# We do not put a default on num_buckets to encourage users to play with this parameter
if num_buckets is None and len(bucket_boundaries) == 0:
raise RuntimeError(
"Please specify either num_buckets or bucket_boundaries. "
"Check the docs and/or the tutorial!"
)
assert lengths_list is not None
max_len = int(self.args.audio_max_length * self.args.encodec_sr)
lengths_list = [min(l, max_len) for l in lengths_list] # clip any utterance longer than max_len down to max_len; the dataset's __getitem__ does the same
for indx in range(len(lengths_list)):
self._ex_lengths[str(indx)] = lengths_list[indx]
# NOTE: the upstream SpeechBrain fallback that derived lengths via length_func on a
# DynamicItemDataset is bypassed here; lengths_list is required instead (see the assert above).
if len(bucket_boundaries) > 0:
if not all([x >= 0 for x in bucket_boundaries]):
raise ValueError(
"All elements in bucket boundaries should be non-negative (>= 0)."
)
if not len(set(bucket_boundaries)) == len(bucket_boundaries):
raise ValueError(
"Bucket_boundaries should not contain duplicates."
)
np.testing.assert_array_equal(
np.array(bucket_boundaries),
np.array(sorted(bucket_boundaries)),
err_msg="The arg bucket_boundaries should be an ascending sorted list of non-negative values!",
)
self._bucket_boundaries = np.array(sorted(bucket_boundaries))
else:
# use num_buckets
self._bucket_boundaries = np.array(
self._get_boundaries_through_warping(
# max_batch_length=max_batch_length,
max_batch_length=max(lengths_list),
num_quantiles=num_buckets,
)
)
self._max_batch_length = max_batch_length
self._shuffle_ex = shuffle
self._batch_ordering = batch_ordering
self._seed = seed
self._drop_last = drop_last
if max_batch_ex is None:
max_batch_ex = np.inf
self._max_batch_ex = max_batch_ex
# Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
self._bucket_lens = [
max(1, int(max_batch_length / self._bucket_boundaries[i]))
for i in range(len(self._bucket_boundaries))
] + [1]
self._epoch = epoch
self._cur_step = 0
self.continue_flag = False
self._generate_batches()
self.num_samples = int(math.floor(len(self._batches) / self.num_replicas))
self.total_size = int(self.num_samples * self.num_replicas)
self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas]
assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples}, len(self._replica_batches): {len(self._replica_batches)}"
logging.info(f"len(self._batches): {len(self._batches)}")
logging.info(f"self.num_replicas: {self.num_replicas}")
logging.info(f"num of batches on each replica: {self.num_samples}")
def get_durations(self, batch):
"""Gets durations of the elements in the batch."""
return [self._ex_lengths[str(idx)] for idx in batch]
def _get_boundaries_through_warping(
self, max_batch_length: int, num_quantiles: int,
) -> List[int]:
# NOTE: the following lines do not cover that there is only one example in the dataset
# warp frames (duration) distribution of train data
logging.info("Batch quantisation in latent space")
# linspace set-up
num_boundaries = num_quantiles + 1
# create latent linearly equal spaced buckets
latent_boundaries = np.linspace(
1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles,
)
# get quantiles using lognormal distribution
quantiles = lognorm.ppf(latent_boundaries, 1)
# scale up to max_batch_length
bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
# compute resulting bucket length multipliers
length_multipliers = [
bucket_boundaries[x + 1] / bucket_boundaries[x]
for x in range(num_quantiles - 1)
]
# logging
logging.debug(
"Latent bucket boundary - buckets: {} - length multipliers: {}".format(
list(map("{:.2f}".format, bucket_boundaries)),
list(map("{:.2f}".format, length_multipliers)),
)
)
return list(sorted(bucket_boundaries))
def _permute_batches(self):
if self._batch_ordering == "random":
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self._seed + self._epoch)
# Since the seed depends only on self._seed and self._epoch, it is identical
# across processes under DDP, so every replica generates the same batch order.
# This matters because each replica takes a strided, non-overlapping slice of
# the batches, and together the slices must cover the entire dataset.
sampler = torch.randperm(
len(self._batches), generator=g
).tolist() # type: ignore
tmp = []
for idx in sampler:
tmp.append(self._batches[idx])
self._batches = tmp
elif self._batch_ordering == "ascending":
self._batches = sorted(
self._batches,
key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
)
elif self._batch_ordering == "descending":
self._batches = sorted(
self._batches,
key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
reverse=True,
)
else:
raise NotImplementedError
def _generate_batches(self):
logging.info("DynamicBatchSampler: Generating dynamic batches")
if self._shuffle_ex:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self._seed + self._epoch)
# Since the seed depends only on self._seed and self._epoch, it is identical
# across processes under DDP, so every replica generates the same batch order.
# This matters because each replica takes a strided, non-overlapping slice of
# the batches, and together the slices must cover the entire dataset.
sampler = torch.randperm(len(self._dataset), generator=g).tolist() # type: ignore
# pyp note: these are actually randomly permuted indices
else:
# take examples as they are: e.g. they have been sorted
sampler = range(len(self._dataset)) # type: ignore
self._batches = []
bucket_batches = [[] for i in self._bucket_lens]
stats_tracker = [
{"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0}
for i in self._bucket_lens
]
for idx in sampler:
# length of pre-sampled audio
item_len = self._ex_lengths[str(idx)]
# find the first bucket whose right boundary is >= item_len
bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
# fill audio's duration into that bucket
bucket_batches[bucket_id].append(idx)
stats_tracker[bucket_id]["min"] = min(
stats_tracker[bucket_id]["min"], item_len
)
stats_tracker[bucket_id]["max"] = max(
stats_tracker[bucket_id]["max"], item_len
)
stats_tracker[bucket_id]["tot"] += item_len
stats_tracker[bucket_id]["n_ex"] += 1
# track #samples - why not duration/#frames; rounded up?
# keep track of durations, if necessary
if (
len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
or len(bucket_batches[bucket_id]) >= self._max_batch_ex
):
self._batches.append(bucket_batches[bucket_id])
bucket_batches[bucket_id] = []
# keep track of durations
# Dump remaining batches
if not self._drop_last:
for batch in bucket_batches:
if batch:
self._batches.append(batch)
self._permute_batches() # possibly reorder batches
if self._epoch == 0: # only log at first epoch
# frames per batch & their padding remaining
boundaries = [0] + self._bucket_boundaries.tolist()
for bucket_indx in range(len(self._bucket_boundaries)):
try:
num_batches = stats_tracker[bucket_indx]["tot"] // (
self._max_batch_length
)
pad_factor = (
stats_tracker[bucket_indx]["max"]
- stats_tracker[bucket_indx]["min"]
) / (
stats_tracker[bucket_indx]["tot"]
/ stats_tracker[bucket_indx]["n_ex"]
)
except ZeroDivisionError:
num_batches = 0
pad_factor = 0
logging.debug(
(
"DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
+ "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}."
).format(
bucket_indx,
boundaries[bucket_indx],
boundaries[bucket_indx + 1],
self._bucket_lens[bucket_indx],
stats_tracker[bucket_indx]["n_ex"],
num_batches,
pad_factor * 100,
)
)
if self.verbose:
batch_stats = {
"tot_frames": [],
"tot_pad_frames": [],
"pad_%": [],
}
for batch in self._batches:
tot_frames = sum(
[self._ex_lengths[str(idx)] for idx in batch]
)
batch_stats["tot_frames"].append(tot_frames)
max_frames = max(
[self._ex_lengths[str(idx)] for idx in batch]
)
tot_pad = sum(
[
max_frames - self._ex_lengths[str(idx)]
for idx in batch
]
)
batch_stats["tot_pad_frames"].append(tot_pad)
batch_stats["pad_%"].append(tot_pad / tot_frames * 100)
padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
padding_details = "DynamicBatchSampler: " + padding_details
for i in range(len(self._batches)):
logging.debug(
padding_details.format(
i,
batch_stats["tot_frames"][i],
len(self._batches[i]),
batch_stats["tot_pad_frames"][i],
batch_stats["pad_%"][i],
)
)
def __iter__(self):
for batch in self._replica_batches:
yield batch
# if self._shuffle_ex: # re-generate examples if ex_ordering == "random"
# self._generate_batches()
# if self._batch_ordering == "random":
# # we randomly permute the batches only --> faster
# self._permute_batches()
def set_epoch(self, epoch):
"""
You can also just set self._epoch directly, but we maintain this interface
to mirror torch.utils.data.distributed.DistributedSampler.
"""
self._epoch = epoch
self._generate_batches()
self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas]
self.num_samples = int(math.floor(len(self._batches) / self.num_replicas))
assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples}, len(self._replica_batches): {len(self._replica_batches)}"
if self.continue_flag:
self.continue_flag = False
self._replica_batches = self._replica_batches[self._cur_step:]
self.num_samples = len(self._replica_batches)
def __len__(self):
return self.num_samples
def set_epoch_resume(self, epoch, cur_step):
self.continue_flag = True
self._epoch = epoch
self._cur_step = cur_step
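The lognormal warping used in _get_boundaries_through_warping can be reproduced with the standard library alone, since lognorm.ppf(p, 1) for a standard lognormal equals exp of the inverse standard-normal CDF at p. A minimal sketch of the boundary computation under that assumption (the helper name warp_boundaries is ours, not part of the class):

```python
import math
from statistics import NormalDist  # stdlib stand-in for scipy.stats.lognorm

def warp_boundaries(max_len: float, num_quantiles: int) -> list:
    # num_quantiles equally spaced probabilities strictly inside (0, 1)
    num_boundaries = num_quantiles + 1
    latent = [(i + 1) / num_boundaries for i in range(num_quantiles)]
    # lognorm.ppf(p, 1) == exp(inverse standard-normal CDF at p)
    quantiles = [math.exp(NormalDist().inv_cdf(p)) for p in latent]
    # rescale so the last boundary lands on the longest example
    return [q * max_len / quantiles[-1] for q in quantiles]

bounds = warp_boundaries(1000.0, 6)
```

With max_len = max(lengths_list), this yields ascending boundaries whose last value coincides with the longest clipped utterance, matching how the constructor above calls the method.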

z_scripts/e830M.sh (new file, 71 lines)
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate voicecraft
export CUDA_VISIBLE_DEVICES=0,1,2,3
export WORLD_SIZE=4
dataset=gigaspeech
mkdir -p ./logs/${dataset}
exp_root="/data/scratch/pyp/exp_pyp/VoiceCraft"
exp_name=e830M
dataset_dir="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl"
encodec_codes_folder_name="encodec_16khz_4codebooks"
# export CUDA_LAUNCH_BLOCKING=1 # for debugging
torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_per_node=${WORLD_SIZE} \
../main.py \
--reduced_eog 1 \
--drop_long 1 \
--eos 2051 \
--n_special 4 \
--pad_x 0 \
--codebook_weight "[5,1,0.5,0.1]" \
--encodec_sr 50 \
--num_steps 50000 \
--lr 0.05 \
--warmup_fraction 0.01 \
--optimizer_name "ScaledAdam" \
--pseudo_epoch_size 3000 \
--reduce_lr_start_step 3000 \
--reduce_lr_start_epoch 4 \
--clipping_update_period 1000 \
--d_model 2048 \
--audio_embedding_dim 2048 \
--nhead 16 \
--num_decoder_layers 16 \
--max_num_tokens 100000 \
--gradient_accumulation_steps 26 \
--val_max_num_tokens 6000 \
--num_buckets 6 \
--audio_max_length 20 \
--audio_min_length 2 \
--text_max_length 400 \
--text_min_length 10 \
--mask_len_min 1 \
--mask_len_max 600 \
--tb_write_every_n_steps 10 \
--print_every_n_steps 400 \
--val_every_n_steps 1600 \
--text_vocab_size 100 \
--text_pad_token 100 \
--phn_folder_name "phonemes" \
--manifest_name "manifest_large16khz_lessambi" \
--encodec_folder_name ${encodec_codes_folder_name} \
--audio_vocab_size 2048 \
--empty_token 2048 \
--eog 2049 \
--audio_pad_token 2050 \
--n_codebooks 4 \
--max_n_spans 3 \
--shuffle_mask_embedding 0 \
--mask_sample_dist poisson1 \
--max_mask_portion 0.9 \
--min_gap 5 \
--num_workers 8 \
--dynamic_batching 1 \
--dataset $dataset \
--exp_dir "${exp_root}/${dataset}/${exp_name}" \
--dataset_dir ${dataset_dir}
# >> ./logs/${dataset}/${exp_name}.log 2>&1
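The sampler this script launches shards whole batches across ranks rather than individual examples: each rank keeps floor(len(batches) / num_replicas) batches and slices with a stride of num_replicas, so the shards are disjoint and equal-sized. A standalone sketch of that slicing (the helper name shard_batches is illustrative, not part of the code above):

```python
def shard_batches(batches, num_replicas, rank):
    # mirrors self._batches[self.rank:self.total_size:self.num_replicas]
    num_samples = len(batches) // num_replicas  # batches kept per replica
    total_size = num_samples * num_replicas     # trailing remainder is dropped
    return batches[rank:total_size:num_replicas]

batches = [[i] for i in range(10)]  # ten toy batches of one example each
shards = [shard_batches(batches, 4, r) for r in range(4)]
# every rank gets 2 batches; batches 8 and 9 are dropped this epoch
```

Because every rank builds the identical batch list from the shared seed before slicing, the shards partition the kept batches without overlap, which is why the deterministic seeding in _generate_batches matters.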