Mirror of https://github.com/jasonppy/VoiceCraft.git (synced 2025-06-05 21:49:11 +02:00)

Commit: extraction,training,data,weights
README.md (59 changed lines)
@@ -1,7 +1,7 @@
 # VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
 [Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf)

-TL;DR:
+### TL;DR
 VoiceCraft is a token infilling neural codec language model, that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.

 To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference.
@@ -12,22 +12,25 @@ The TODOs left will be completed by the end of March 2024.
 - [x] Codebase upload
 - [x] Environment setup
 - [x] Inference demo for speech editing and TTS
-- [ ] Upload model weights
-- [ ] Training guidance
-- [ ] Upload the RealEdit dataset
+- [x] Training guidance
+- [x] Upload the RealEdit dataset and training manifest
+- [ ] Upload model weights (encodec weights are up)

 ## Environment setup
 ```bash
 conda create -n voicecraft python=3.9.16
 conda activate voicecraft

-pip install torch==2.0.1 torchaudio==2.0.2 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201
+pip install torch==2.0.1 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201
 apt-get install ffmpeg # if you don't already have ffmpeg installed
 pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
 apt-get install espeak-ng # backend for the phonemizer installed below
+pip install tensorboard==2.16.2
 pip install phonemizer==3.2.1
-pip install tensorboard
+pip install torchaudio==2.0.2
-pip install datasets==2.12.0
+pip install datasets==2.16.0
+pip install torchmetrics==0.11.1
 # install MFA for getting forced-alignment, this could take a few minutes
 conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068
 # conda install pocl # above gives a warning for installing pocl, not sure if really need this
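To confirm the pinned torch/torchaudio builds landed correctly before moving on, a quick check (standard torch calls, not part of the original README; the CUDA result depends on your machine):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import torchaudio; print(torchaudio.__version__)"
```

On a CUDA 11.7-compatible box this should print 2.0.1 and True.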
@@ -36,9 +39,51 @@ conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=
 conda install -n voicecraft ipykernel --update-deps --force-reinstall
 ```

+If you have encountered version issues when running things, checkout [environment.yml](./environment.yml) for exact matching.

 ## Inference Examples
 Checkout [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb)

+## Training
+To train a VoiceCraft model, you need to prepare the following parts:
+1. utterances and their transcripts
+2. encode the utterances into codes using e.g. Encodec
+3. convert transcripts into phoneme sequence, and a phoneme set (we named it vocab.txt)
+4. manifest (i.e. metadata)
+
+Steps 1, 2, and 3 are handled in [./data/phonemize_encodec_encode_hf.py](./data/phonemize_encodec_encode_hf.py), where
+1. Gigaspeech is downloaded through HuggingFace. Note that you need to sign an agreement in order to download the dataset (it needs your auth token)
+2. phoneme sequence and encodec codes are also extracted using the script.
+
+An example run:
+
+```bash
+conda activate voicecraft
+export CUDA_VISIBLE_DEVICES=0
+cd ./data
+python phonemize_encodec_encode_hf.py \
+--dataset_size xs \
+--download_to path/to/store_huggingface_downloads \
+--save_dir path/to/store_extracted_codes_and_phonemes \
+--encodec_model_path path/to/encodec_model \
+--mega_batch_size 120 \
+--batch_size 32 \
+--max_len 30000
+```
+
+where encodec_model_path is available [here](https://huggingface.co/pyp1/VoiceCraft). This model is trained on Gigaspeech XL; it has 56M parameters, 4 codebooks, and each codebook has 2048 codes. Details are described in our [paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf). If you encounter OOM during extraction, try decreasing the batch_size and/or max_len.
+The extracted codes, phonemes, and vocab.txt will be stored at `path/to/store_extracted_codes_and_phonemes/${dataset_size}/{encodec_16khz_4codebooks,phonemes,vocab.txt}`.
+
+As for the manifest, please download train.txt and validation.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main), and put them under `path/to/store_extracted_codes_and_phonemes/manifest/`. Please also download vocab.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) if you want to use our pretrained VoiceCraft model (so that the phoneme-to-token matching is the same).
+
+Now, you are good to start training!
+
+```bash
+conda activate voicecraft
+cd ./z_scripts
+bash e830M.sh
+```

 ## License
 The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under the Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some of the code from other repositories that are under different licenses: `./models/codebooks_patterns.py` is under the MIT license; `./models/modules`, `./steps/optim.py`, and `data/tokenizer.py` are under the Apache License, Version 2.0; the phonemizer we used is under the GNU 3.0 License. For a drop-in replacement of the phonemizer (i.e. text to IPA phoneme mapping), try [g2p](https://github.com/roedoejet/g2p) (MIT License) or [OpenPhonemizer](https://github.com/NeuralVox/OpenPhonemizer) (BSD-3-Clause Clear), although these are not tested.
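The README does not show what the extracted artifacts look like. Going by `write_array_to_txt_file` in the extraction scripts below, each `encodec_16khz_4codebooks/{segment_id}.txt` stores one codebook stream per line as space-separated integer codes; a minimal loading sketch (the path is a placeholder):

```python
import torch

def load_codes(fn):
    # one line per codebook; tokens are integer encodec code indices
    with open(fn) as f:
        return torch.LongTensor([[int(tok) for tok in line.split()] for line in f])

codes = load_codes("path/to/store_extracted_codes_and_phonemes/xs/encodec_16khz_4codebooks/SEGMENT_ID.txt")
print(codes.shape)  # expected [K, T], K=4 codebooks for this encodec checkpoint
```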
Deleted file (160 lines; file name not preserved in this mirror):
```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
    parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!")
    parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files")
    parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
    parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
    parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
    return parser.parse_args()

if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    import os
    import numpy as np
    import torch
    import torchaudio
    import tqdm
    import time

    args = parse_args()

    manifest_dir = args.manifest_root # this dir is scp-ed
    audio_dir = args.audio_dir # this is scp-ed flac dir
    encodec_signature = args.encodec_model_path.split("/")[-2]
    save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}")
    os.makedirs(save_codes_dir, exist_ok=True)

    # model_sr = 16000
    # downsample_rate = 320
    # model_code_sr = 50
    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()
        logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
        return inds[::-1]

    def write_array_to_txt_file(array, filename):
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))

    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            # self.data = gs[split]
            self.split = split
            self.audio_root = audio_dir
            manifest_fn = os.path.join(manifest_dir, split+".txt")
            with open(manifest_fn, "r") as rf:
                self.data = [l.strip().split("\t") for l in rf.readlines()]
        def __len__(self):
            return len(self.data)
        def __getitem__(self, ind):
            try:
                afn = self.data[ind][0]
                fn = os.path.join(self.audio_root, afn)
                audio, sr = torchaudio.load(fn)
                assert sr == args.model_sr, sr
            except Exception as e:
                logging.info(f"{e}")
                return None, None, None
            assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
            return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0]
        def collate(self, batch):
            lens, audios, segment_ids = [], [], []
            for item in batch:
                if item[0] != None:
                    audios.append(item[0])
                    lens.append(item[1])
                    segment_ids.append(item[2])
            return audios, lens, segment_ids

    # load the encodec model
    from audiocraft.solvers import CompressionSolver
    model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
    model = model.cuda()
    model = model.eval()
    model = torch.nn.DataParallel(model)

    # setup dataloader
    mega_batch_size = 2100
    batch_size = args.batch_size
    train_dataset = mydataset('train')
    train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
    validation_dataset = mydataset('validation')
    validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
    test_dataset = mydataset('test')
    test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
    splits = ['validation', 'test', 'train']
    loaders = [validation_loader, test_loader, train_loader]
    # splits = ['validation'] # NOTE this is for debug, for example, see if the
    # loaders = [validation_loader]
    for split, loader in zip(splits, loaders):
        skip = 0
        logging.info(f"now processing split {split}...")
        mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
        # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
        logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples")
        # with open(mani_fn, "a") as mani_wf: # resume from where we failed
        for m, mega_batch in enumerate(loader):
            logging.info(f"====================================")
            logging.info(f"====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
            lengths = np.array(mega_batch[1])
            sorted_inds = sort_by_audio_len(lengths)
            for j in range(len(sorted_inds))[::-1]:
                if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s)
                    skip += 1
                    del sorted_inds[j]

            n_steps = int(np.ceil(len(sorted_inds) / batch_size))
            for n in tqdm.tqdm(range(n_steps), disable=True):
                inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
                wav_batch = [mega_batch[0][id] for id in inds_used]
                all_lens = [mega_batch[1][id] for id in inds_used]
                segment_id_batch = [mega_batch[2][id] for id in inds_used]
                # print(segment_id_batch)
                padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                with torch.no_grad():
                    if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes
                        codes = []
                        inwav = padded_wav.cuda()
                        codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
                        codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
                        codes = torch.cat(codes, dim=0)
                    else:
                        encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model
                        # logging.info(f"encoded_frames: {encoded_frames[0].shape}")
                        codes = encoded_frames[0].cpu()

                for i, length in enumerate(all_lens):
                    save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
                    actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
                    cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
                    write_array_to_txt_file(cur_code, save_fn)
                    # mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
                    # if i == 10:
                    #     raise
            # break
        # logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
        logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
        # break
```
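Functionally, the deleted script and its replacement below differ mainly in data source and encoder API: the old one read pre-downloaded FLAC files via torchaudio and called a DataParallel-wrapped model as `model(wav, encode=True)`, while the new one streams GigaSpeech from HuggingFace and calls `model.encode(wav)`. Both share the same OOM-avoidance trick of splitting an oversized batch into two forward passes; a self-contained sketch of that pattern (the function name and stand-in encoder are illustrative, not from the repo):

```python
import torch

def encode_in_halves(encode_fn, padded_wav, all_lens, max_len):
    # if the longest item exceeds max_len, run two half-batches and concatenate
    with torch.no_grad():
        if max(all_lens) > max_len and len(all_lens) > 1:
            half = len(padded_wav) // 2
            return torch.cat([encode_fn(padded_wav[:half]), encode_fn(padded_wav[half:])], dim=0)
        return encode_fn(padded_wav)

# toy stand-in for model.encode(...)[0]: maps [B, 1, T] audio to [B, 4, T//320] codes
fake_encode = lambda w: torch.zeros(w.shape[0], 4, w.shape[-1] // 320, dtype=torch.long)
print(encode_in_halves(fake_encode, torch.randn(8, 1, 64000), [64000] * 8, 30000).shape)
```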
Changed file (name not preserved in this mirror):
@@ -54,8 +54,6 @@ class dataset(torch.utils.data.Dataset):
                 y = [[int(n)+self.args.n_special for n in l] for l in encos]
             else:
                 y = [[int(n) for n in l] for l in encos]
-            if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
-                y = y[:1]
         except Exception as e:
             logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
             logging.info(f"error message: {e}")

@@ -141,15 +139,15 @@ class dataset(torch.utils.data.Dataset):
         if self.args.pad_x:
             res["x"] = torch.stack(out["x"], dim=0)
         else:
-            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
+            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
         res["x_lens"] = torch.LongTensor(out["x_len"])
         if self.args.dynamic_batching:
             if out['y'][0].ndim==2:
-                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
+                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
                 res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
             else:
                 assert out['y'][0].ndim==1, out['y'][0].shape
-                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
+                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
         else:
             res['y'] = torch.stack(out['y'], dim=0)
         res["y_lens"] = torch.LongTensor(out["y_len"])
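The padding change above is easier to follow with shapes: each `y` is a [K, T] code array, `pad_sequence` pads along the first dimension, so items are transposed to [T, K] first and the [T, B, K] result is permuted back to [B, K, T]. A minimal sketch (the pad value 2049 is illustrative; training uses `self.args.audio_pad_token`):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

audio_pad_token = 2049  # illustrative; the real value comes from the training args
ys = [torch.ones(4, 10, dtype=torch.long), torch.ones(4, 7, dtype=torch.long)]  # [K, T] each

padded = pad_sequence([y.transpose(1, 0) for y in ys], padding_value=audio_pad_token)  # [T_max, B, K]
batch = padded.permute(1, 2, 0)  # -> [B, K, T_max]
print(batch.shape, (batch[1, :, 7:] == audio_pad_token).all())  # torch.Size([2, 4, 10]) tensor(True)
```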
data/phonemize_encodec_encode_hf.py (new file, 206 lines):
```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
    parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl. we use xl for VoiceCraft training, xs is good for debugging')
    parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
    parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
    parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
    parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
    parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
    parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
    parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine')
    return parser.parse_args()

if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    args = parse_args()

    import os
    import numpy as np
    import torch
    import tqdm
    import time
    from datasets import load_dataset, DownloadConfig

    from tokenizer import TextTokenizer, tokenize_text

    # get the path
    phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
    codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
    vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
    os.makedirs(phn_save_root, exist_ok=True)
    os.makedirs(codes_save_root, exist_ok=True)

    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()
        logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
        return inds[::-1]

    def write_array_to_txt_file(array, filename):
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))

    ### phonemization
    # load tokenizer
    # load the encodec model
    from audiocraft.solvers import CompressionSolver
    model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
    model = model.cuda()
    model = model.eval()
    text_tokenizer = TextTokenizer()

    # https://github.com/SpeechColab/GigaSpeech
    # there are only four different punctuations
    # need to check whether there are other < started strings
    punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
    gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are safely kept as the original sym when using tokenize_text
    punc2sym.update(gar2sym)

    word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
    forbidden_words = set(['#%#', '##%', '%%#', '%#%'])

    dc = DownloadConfig(cache_dir=args.download_to)
    stime = time.time()
    logging.info("loading the dataset...")
    gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc)
    logging.info(f"time spent on loading the dataset: {time.time() - stime:.2f} seconds")

    splits = ['validation', 'test', 'train']

    logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}")
    logging.info(f"phonemizing...")
    phn_vocab = set()
    all_lens = []

    # you will see a ton of [WARNING] words_mismatch.py:88......, it's not an issue
    for split in tqdm.tqdm(splits):
        skip = 0
        logging.info(f"now processing split {split}...")
        for item in tqdm.tqdm(gs[split]):
            save_fn = os.path.join(phn_save_root, item['segment_id']+".txt")
            text = item['text']
            if sum(word in forbidden_words for word in text.split(" ")):
                logging.info(f"skip {item['segment_id']}, because it contains forbidden words. Its transcript: {text}")
                skip += 1
                continue
            for k, v in punc2sym.items():
                text = text.replace(k, v)
            phn = tokenize_text(text_tokenizer, text)
            phn_seq = " ".join(phn)
            for k, v in word2sym.items():
                phn_seq = phn_seq.replace(k, v)
            phn_vocab.update(phn_seq.split(" "))
            all_lens.append(len(phn_seq.split(" ")))
            with open(save_fn, "w") as f:
                f.write(phn_seq)
        logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbidden words")

    print(f"phn vocab size: {len(list(phn_vocab))}")
    print("phn sequence stats: ")
    print(f"longest: {max(all_lens)}")
    print(f"shortest: {min(all_lens)}")
    print(f"median: {np.quantile(all_lens, 0.5)}")
    print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}")
    print("write vocabulary to ", vocab_fn)
    with open(vocab_fn, "w") as f:
        for i, phn in enumerate(list(phn_vocab)):
            if i < len(list(phn_vocab)) - 1:
                f.write(f"{str(i)} {phn}\n")
            else:
                f.write(f"{str(i)} {phn}")

    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            self.data = gs[split]
        def __len__(self):
            return len(self.data)
        def __getitem__(self, ind):
            try:
                segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time']
            except:
                return None, None, None, None, None, None

            return segment_id, audio, sr, text, begin_time, end_time
        def collate(self, batch):
            res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []}
            for item in batch:
                if item[0] != None:
                    res['segment_id'].append(item[0])
                    res['audio'].append(item[1])
                    res['sr'].append(item[2])
                    res['text'].append(item[3])
                    res['begin_time'].append(item[4])
                    res['end_time'].append(item[5])
            return res

    ## encodec codes extraction
    logging.info("encodec encoding...")
    train_dataset = mydataset('train')
    train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
    validation_dataset = mydataset('validation')
    validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
    test_dataset = mydataset('test')
    test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
    splits = ['validation', 'test', 'train']
    loaders = [validation_loader, test_loader, train_loader]
    # splits = ['validation'] # for debug
    # loaders = [validation_loader]
    for split, loader in zip(splits, loaders):
        skip = 0
        logging.info(f"now processing split {split}...")
        mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size))
        logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples")
        for m, mega_batch in enumerate(loader):
            logging.info(f"====================================")
            logging.info(f"====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
            lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time'])
            sorted_inds = sort_by_audio_len(lengths)
            for j in range(len(sorted_inds))[::-1]:
                if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s), or too long (longer than len_cap seconds)
                    skip += 1
                    del sorted_inds[j]

            n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
            for n in tqdm.tqdm(range(n_steps), disable=True):
                inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
                audio_batch = [mega_batch['audio'][id] for id in inds_used]
                sr_batch = [mega_batch['sr'][id] for id in inds_used]
                segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used]
                text_batch = [mega_batch['text'][id] for id in inds_used]
                padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                all_lens = [lengths[id] for id in inds_used]
                with torch.no_grad():
                    if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes
                        codes = []
                        inwav = padded_wav.cuda()
                        codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
                        codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
                        codes = torch.cat(codes, dim=0)
                    else:
                        encoded_frames = model.encode(padded_wav.cuda())
                        # logging.info(f"encoded_frames: {encoded_frames[0].shape}")
                        codes = encoded_frames[0].cpu()

                for i, length in enumerate(all_lens):
                    save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
                    actual_len = round(length * args.model_code_sr) # length is in seconds; model_code_sr is 50 codes/sec
                    cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
                    write_array_to_txt_file(cur_code, save_fn)
```
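For reference, the `vocab.txt` written by the loop above holds one `<id> <phoneme>` pair per line. A minimal sketch for reading it back into a phoneme-to-id map (the path is a placeholder):

```python
def load_vocab(vocab_fn):
    # each line: "<id> <phoneme>", as written by phonemize_encodec_encode_hf.py
    phn2id = {}
    with open(vocab_fn) as f:
        for line in f:
            idx, phn = line.rstrip("\n").split(" ", 1)
            phn2id[phn] = int(idx)
    return phn2id

phn2id = load_vocab("path/to/store_extracted_codes_and_phonemes/xs/vocab.txt")
```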
environment.yml (new file, 417 lines):
```yaml
name: voicecraft
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - aom=3.8.2=h59595ed_0
  - asttokens=2.4.1=pyhd8ed1ab_0
  - atk-1.0=2.38.0=hd4edc92_1
  - audioread=3.0.1=py39hf3d152e_1
  - backcall=0.2.0=pyh9f0ad1d_0
  - baumwelch=0.3.7=h00ab1b0_5
  - biopython=1.79=py39hb9d737c_3
  - brotli=1.1.0=hd590300_1
  - brotli-bin=1.1.0=hd590300_1
  - brotli-python=1.1.0=py39h3d6467e_1
  - bzip2=1.0.8=hd590300_5
  - ca-certificates=2024.2.2=hbcca054_0
  - cairo=1.18.0=h3faef2a_0
  - certifi=2024.2.2=pyhd8ed1ab_0
  - cffi=1.16.0=py39h7a31438_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - click=8.1.7=unix_pyh707e725_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - comm=0.2.2=pyhd8ed1ab_0
  - contourpy=1.2.0=py39h7633fee_0
  - cycler=0.12.1=pyhd8ed1ab_0
  - dataclassy=1.0.1=pyhd8ed1ab_0
  - dav1d=1.2.1=hd590300_0
  - debugpy=1.8.1=py39h3d6467e_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - executing=2.0.1=pyhd8ed1ab_0
  - expat=2.6.2=h59595ed_0
  - ffmpeg=6.1.1=gpl_h38e077a_106
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_1
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.49.0=py39hd1e30aa_0
  - freetype=2.12.1=h267a509_2
  - fribidi=1.0.10=h36c2ea0_0
  - gdk-pixbuf=2.42.10=h829c605_5
  - gettext=0.21.1=h27087fc_0
  - giflib=5.2.1=h0b41bf4_3
  - gmp=6.3.0=h59595ed_1
  - gnutls=3.7.9=hb077bed_0
  - graphite2=1.3.13=h58526e2_1001
  - graphviz=9.0.0=h78e8752_1
  - greenlet=3.0.3=py39h3d6467e_0
  - gtk2=2.24.33=h280cfa0_4
  - gts=0.7.6=h977cf35_4
  - harfbuzz=8.3.0=h3d44ed6_0
  - hdbscan=0.8.33=py39h44dd56e_4
  - icu=73.2=h59595ed_0
  - idna=3.6=pyhd8ed1ab_0
  - importlib-metadata=7.0.2=pyha770c72_0
  - importlib-resources=6.3.0=pyhd8ed1ab_0
  - importlib_metadata=7.0.2=hd8ed1ab_0
  - importlib_resources=6.3.0=pyhd8ed1ab_0
  - ipykernel=6.29.3=pyhd33586a_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - joblib=1.3.2=pyhd8ed1ab_0
  - jupyter_client=8.6.1=pyhd8ed1ab_0
  - jupyter_core=5.7.2=py39hf3d152e_0
  - kaldi=5.5.1068=cpu_h31769b2_2
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.5=py39h7633fee_1
  - kneed=0.8.5=pyhd8ed1ab_0
  - krb5=1.21.2=h659d440_0
  - lame=3.100=h166bdaf_1003
  - lazy_loader=0.3=pyhd8ed1ab_0
  - lcms2=2.16=hb7c19ff_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - lerc=4.0.0=h27087fc_0
  - libabseil=20240116.1=cxx17_h59595ed_2
  - libass=0.17.1=h8fe9dca_1
  - libblas=3.9.0=21_linux64_openblas
  - libbrotlicommon=1.1.0=hd590300_1
  - libbrotlidec=1.1.0=hd590300_1
  - libbrotlienc=1.1.0=hd590300_1
  - libcblas=3.9.0=21_linux64_openblas
  - libclang-cpp15=15.0.7=default_hb11cfb5_4
  - libdeflate=1.19=hd590300_0
  - libdrm=2.4.120=hd590300_0
  - libedit=3.1.20191231=he28a2e2_2
  - libexpat=2.6.2=h59595ed_0
  - libffi=3.4.2=h7f98852_5
  - libflac=1.4.3=h59595ed_0
  - libgcc-ng=13.2.0=h807b86a_5
  - libgd=2.3.3=h119a65a_9
  - libgfortran-ng=13.2.0=h69a702a_5
  - libgfortran5=13.2.0=ha4646dd_5
  - libglib=2.80.0=hf2295e7_0
  - libgomp=13.2.0=h807b86a_5
  - libhwloc=2.9.3=default_h554bfaf_1009
  - libiconv=1.17=hd590300_2
  - libidn2=2.3.7=hd590300_0
  - libjpeg-turbo=3.0.0=hd590300_1
  - liblapack=3.9.0=21_linux64_openblas
  - liblapacke=3.9.0=21_linux64_openblas
  - libllvm14=14.0.6=hcd5def8_4
  - libllvm15=15.0.7=hb3ce162_4
  - libllvmspirv15=15.0.0=h0cdce71_1
  - libnsl=2.0.1=hd590300_0
  - libogg=1.3.4=h7f98852_1
  - libopenblas=0.3.26=pthreads_h413a1c8_0
  - libopenvino=2024.0.0=h2e90f83_1
  - libopenvino-auto-batch-plugin=2024.0.0=hd5fc58b_1
  - libopenvino-auto-plugin=2024.0.0=hd5fc58b_1
  - libopenvino-hetero-plugin=2024.0.0=h3ecfda7_1
  - libopenvino-intel-cpu-plugin=2024.0.0=h2e90f83_1
  - libopenvino-intel-gpu-plugin=2024.0.0=h2e90f83_1
  - libopenvino-ir-frontend=2024.0.0=h3ecfda7_1
  - libopenvino-onnx-frontend=2024.0.0=h757c851_1
  - libopenvino-paddle-frontend=2024.0.0=h757c851_1
  - libopenvino-pytorch-frontend=2024.0.0=h59595ed_1
  - libopenvino-tensorflow-frontend=2024.0.0=hca94c1a_1
  - libopenvino-tensorflow-lite-frontend=2024.0.0=h59595ed_1
  - libopus=1.3.1=h7f98852_1
  - libpciaccess=0.18=hd590300_0
  - libpng=1.6.43=h2797004_0
  - libpq=16.2=h33b98f1_0
  - libprotobuf=4.25.3=h08a7969_0
  - librosa=0.10.1=pyhd8ed1ab_0
  - librsvg=2.56.3=he3f83f7_1
  - libsndfile=1.2.2=hc60ed4a_1
  - libsodium=1.0.18=h36c2ea0_1
  - libsqlite=3.45.2=h2797004_0
  - libstdcxx-ng=13.2.0=h7e041cc_5
  - libtasn1=4.19.0=h166bdaf_0
  - libtiff=4.6.0=ha9c0a0a_2
  - libunistring=0.9.10=h7f98852_0
  - libuuid=2.38.1=h0b41bf4_0
  - libva=2.21.0=hd590300_0
  - libvorbis=1.3.7=h9c3ff4c_0
  - libvpx=1.14.0=h59595ed_0
  - libwebp=1.3.2=h658648e_1
  - libwebp-base=1.3.2=hd590300_0
  - libxcb=1.15=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libxml2=2.12.5=h232c23b_0
  - libzlib=1.2.13=hd590300_5
  - llvm-spirv-15=15.0.0=h0cdce71_1
  - mad=0.15.1b=h9c3ff4c_1
  - markdown-it-py=3.0.0=pyhd8ed1ab_0
  - matplotlib-base=3.8.3=py39he9076e7_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - mdurl=0.1.2=pyhd8ed1ab_0
  - montreal-forced-aligner=2.2.17=pyhd8ed1ab_0
  - mpg123=1.32.4=h59595ed_0
  - msgpack-python=1.0.7=py39h7633fee_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - ncurses=6.4=h59595ed_2
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - nettle=3.9.1=h7ab15ed_0
  - ngram=1.3.14=h924138e_2
  - numba=0.59.0=py39h615d6bd_1
  - numpy=1.26.4=py39h474f0d3_0
  - ocl-icd=2.3.2=hd590300_0
  - openfst=1.8.2=h924138e_2
  - openh264=2.4.1=h59595ed_0
  - openjpeg=2.5.2=h488ebb8_0
  - openssl=3.2.1=hd590300_0
  - p11-kit=0.24.1=hc5aa10d_0
  - packaging=24.0=pyhd8ed1ab_0
  - pandas=2.2.1=py39hddac248_0
  - pango=1.52.1=ha41ecd1_0
  - parso=0.8.3=pyhd8ed1ab_0
  - patsy=0.5.6=pyhd8ed1ab_0
  - pcre2=10.43=hcad00b1_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pgvector-python=0.2.5=pyhe093146_0
  - pickleshare=0.7.5=py_1003
  - pillow=10.2.0=py39had0adad_0
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.2=h59595ed_0
  - platformdirs=4.2.0=pyhd8ed1ab_0
  - pocl=5.0=h03a6ac1_2
  - pocl-core=5.0=hdaecddf_2
  - pocl-cpu=5.0=he901f76_2
  - pocl-cpu-minimal=5.0=h5ccd973_2
  - pocl-cuda=5.0=hdaecddf_2
  - pocl-remote=5.0=h5ccd973_2
  - pooch=1.8.1=pyhd8ed1ab_0
  - postgresql=16.2=h7387d8b_0
  - prompt-toolkit=3.0.42=pyha770c72_0
  - prompt_toolkit=3.0.42=hd8ed1ab_0
  - psutil=5.9.8=py39hd1e30aa_0
  - psycopg2=2.9.9=py39h89197e3_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pugixml=1.14=h59595ed_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.17.2=pyhd8ed1ab_0
  - pyparsing=3.1.2=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - pysoundfile=0.12.1=pypyhd8ed1ab_1
  - python=3.9.18=h0755675_1_cpython
  - python-tzdata=2024.1=pyhd8ed1ab_0
  - python_abi=3.9=4_cp39
  - pytz=2024.1=pyhd8ed1ab_0
  - pyyaml=6.0.1=py39hd1e30aa_1
  - pyzmq=25.1.2=py39h8c080ef_0
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - rich=13.7.1=pyhd8ed1ab_0
  - rich-click=1.7.4=pyhd8ed1ab_0
  - scikit-learn=1.2.2=py39hc236052_2
  - scipy=1.12.0=py39h474f0d3_2
  - seaborn=0.13.2=hd8ed1ab_0
  - seaborn-base=0.13.2=pyhd8ed1ab_0
  - setuptools=69.2.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - snappy=1.1.10=h9fff704_0
  - sox=14.4.2=ha5cc309_1018
  - soxr=0.1.3=h0b41bf4_3
  - soxr-python=0.3.7=py39h44dd56e_0
  - sqlalchemy=2.0.28=py39hd1e30aa_0
  - sqlite=3.45.2=h2c6b66d_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - statsmodels=0.14.1=py39h44dd56e_0
  - svt-av1=1.8.0=h59595ed_0
  - tbb=2021.11.0=h00ab1b0_1
  - threadpoolctl=3.3.0=pyhc1e730c_0
  - tk=8.6.13=noxft_h4845f30_101
  - tornado=6.4=py39hd1e30aa_0
  - tqdm=4.66.2=pyhd8ed1ab_0
  - traitlets=5.14.2=pyhd8ed1ab_0
  - typing-extensions=4.10.0=hd8ed1ab_0
  - typing_extensions=4.10.0=pyha770c72_0
  - tzcode=2024a=h3f72095_0
  - tzdata=2024a=h0c530f3_0
  - unicodedata2=15.1.0=py39hd1e30aa_0
  - urllib3=2.2.1=pyhd8ed1ab_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - x264=1!164.3095=h166bdaf_2
  - x265=3.5=h924138e_3
  - xorg-fixesproto=5.0=h7f98852_1002
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.7=h8ee46fc_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxfixes=5.0.3=h7f98852_1004
  - xorg-libxrender=0.9.11=hd590300_0
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - zeromq=4.3.5=h59595ed_1
  - zipp=3.17.0=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.5=hfc55251_0
  - pip:
    - absl-py==2.1.0
    - aiofiles==23.2.1
    - aiohttp==3.9.3
    - aiosignal==1.3.1
    - altair==5.2.0
    - antlr4-python3-runtime==4.9.3
    - anyio==4.3.0
    - async-timeout==4.0.3
    - attrs==23.2.0
    - av==11.0.0
    - babel==2.14.0
    - beautifulsoup4==4.12.3
    - bibtexparser==2.0.0b7
    - bleach==6.1.0
    - blis==0.7.11
    - catalogue==2.0.10
    - clldutils==3.22.2
    - cloudpickle==3.0.0
    - cmake==3.28.3
    - colorlog==6.8.2
    - confection==0.1.4
    - csvw==3.3.0
    - cymem==2.0.8
    - cython==0.29.37
    - datasets==2.16.0
    - defusedxml==0.7.1
    - demucs==4.0.1
    - dill==0.3.6
    - dlinfo==1.2.1
    - docopt==0.6.2
    - dora-search==0.1.12
    - einops==0.7.0
    - encodec==0.1.1
    - exceptiongroup==1.2.0
    - fastapi==0.110.0
    - fastjsonschema==2.19.1
    - ffmpy==0.3.2
    - filelock==3.13.1
    - flashy==0.0.2
    - frozenlist==1.4.1
    - fsspec==2023.10.0
    - gradio==3.50.2
    - gradio-client==0.6.1
    - grpcio==1.62.1
    - h11==0.14.0
    - httpcore==1.0.4
    - httpx==0.27.0
    - huggingface-hub==0.21.4
    - hydra-colorlog==1.2.0
    - hydra-core==1.3.2
    - ipython==8.12.3
    - isodate==0.6.1
    - jinja2==3.1.3
    - jsonschema==4.21.1
    - jsonschema-specifications==2023.12.1
    - julius==0.2.7
    - jupyterlab-pygments==0.3.0
    - lameenc==1.7.0
    - langcodes==3.3.0
    - language-tags==1.2.0
    - lit==18.1.1
    - llvmlite==0.42.0
    - lxml==5.1.0
    - markdown==3.5.2
    - markupsafe==2.1.5
    - mistune==3.0.2
    - mpmath==1.3.0
    - msgpack==1.0.8
    - multidict==6.0.5
    - multiprocess==0.70.14
    - murmurhash==1.0.10
    - nbclient==0.10.0
    - nbconvert==7.16.3
    - nbformat==5.10.3
    - networkx==3.2.1
    - num2words==0.5.13
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-cupti-cu11==11.7.101
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - nvidia-cufft-cu11==10.9.0.58
    - nvidia-curand-cu11==10.2.10.91
    - nvidia-cusolver-cu11==11.4.0.1
    - nvidia-cusparse-cu11==11.7.4.91
    - nvidia-nccl-cu11==2.14.3
    - nvidia-nvtx-cu11==11.7.91
    - omegaconf==2.3.0
    - openunmix==1.2.1
    - orjson==3.9.15
    - pandocfilters==1.5.1
    - pathlib-abc==0.1.1
    - pathy==0.11.0
    - pgvector==0.2.2
    - phonemizer==3.2.1
    - pipreqs==0.5.0
    - praatio==6.2.0
    - preshed==3.0.9
    - protobuf==4.25.3
    - pyarrow==15.0.2
    - pyarrow-hotfix==0.6
    - pydantic==1.10.14
    - pydub==0.25.1
    - pylatexenc==2.10
    - pynini==2.1.6
    - pypinyin==0.48.0
    - python-dateutil==2.9.0.post0
    - python-multipart==0.0.9
    - rdflib==7.0.0
    - referencing==0.33.0
    - regex==2023.12.25
    - responses==0.18.0
    - retrying==1.3.4
    - rfc3986==1.5.0
    - rpds-py==0.18.0
    - safetensors==0.4.2
    - segments==2.2.1
    - semantic-version==2.10.0
    - sentencepiece==0.2.0
    - smart-open==6.4.0
    - sniffio==1.3.1
    - soupsieve==2.5
    - spacy==3.5.2
    - spacy-legacy==3.0.12
    - spacy-loggers==1.0.5
    - srsly==2.4.8
    - starlette==0.36.3
    - submitit==1.5.1
    - sympy==1.12
    - tabulate==0.9.0
    - tensorboard==2.16.2
    - tensorboard-data-server==0.7.2
    - thinc==8.1.12
    - tinycss2==1.2.1
    - tokenizers==0.15.2
    - toolz==0.12.1
    - torch==2.0.1
    - torchaudio==2.0.2
    - torchmetrics==0.11.1
    - transformers==4.38.2
    - treetable==0.2.5
    - triton==2.0.0
    - typer==0.7.0
    - uritemplate==4.1.1
    - uvicorn==0.28.0
    - wasabi==1.1.2
    - webencodings==0.5.1
    - websockets==11.0.3
    - werkzeug==3.0.1
    - xformers==0.0.22
    - xxhash==3.4.1
    - yarg==0.1.9
    - yarl==1.9.4
prefix: /home/pyp/miniconda3/envs/voicecraft
```
Changed file (name not preserved in this mirror; hunk context: class VoiceCraft(nn.Module)):
@@ -504,7 +504,7 @@ class VoiceCraft(nn.Module):
         ntokens = []
         top10acc = []
         for k, (logit, target) in enumerate(zip(logits, targets)):
-            loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None))
+            loss.append(F.cross_entropy(logit, target, reduction='mean'))
             top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
             ntokens.append(len(logit))

@@ -988,6 +988,8 @@ class VoiceCraft(nn.Module):
             for jj in range(1,self.args.n_codebooks):
                 logits_adjust[jj][eog_inference] = -10000
                 logits_adjust[jj][self.args.empty_token] = -10000
+            if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
+                logits_adjust[0][eog_inference] = -10000
             ##################### silence repetition handling #####################
             if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
                 if logits_adjust[0, prev_token] < 0:

@@ -1237,6 +1239,8 @@ class VoiceCraft(nn.Module):
             for jj in range(1,self.args.n_codebooks):
                 logits_adjust[:,jj,eog_inference] = -10000
                 logits_adjust[:,jj,self.args.empty_token] = -10000
+            if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
+                logits_adjust[:,:,eog_inference] = -10000
             ##################### silence repetition handling #####################
             for b in range(batch_size):
                 prev_token = prev_tokens[b]
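The two added guards do the same thing in the single-sample and batched decoding paths: for the first `encodec_sr // 5` generation steps they push the end-of-generation logit to -10000 so sampling cannot terminate immediately. Assuming the 50 Hz code rate used by this encodec checkpoint (`model_code_sr=50` in the extraction script), that is 50 // 5 = 10 steps, i.e. roughly 0.2 seconds of audio. A toy sketch (vocab size and token ids are illustrative, not the repo's actual values):

```python
import torch

encodec_sr, eog_inference, cur_num_gen = 50, 2049, 3  # illustrative values
logits_adjust = torch.randn(4, 2050)  # [n_codebooks, audio vocab incl. special tokens]
if cur_num_gen <= encodec_sr // 5:  # first ~0.2 s of generated codes
    logits_adjust[0][eog_inference] = -10000  # only codebook 0's EOG is masked here
print(logits_adjust[0, eog_inference])  # tensor(-10000.)
```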
Changed file (name not preserved in this mirror; the README above points to z_scripts/e830M.sh):
@@ -7,9 +7,9 @@ export WORLD_SIZE=4
 dataset=gigaspeech
 mkdir -p ./logs/${dataset}

-exp_root="/data/scratch/pyp/exp_pyp/VoiceCraft"
+exp_root="path/to/store/exp_results"
 exp_name=e830M
-dataset_dir="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl"
+dataset_dir="path/to/stored_extracted_codes_and_phonemes/xl" # xs if you only extracted xs in previous step
 encodec_codes_folder_name="encodec_16khz_4codebooks"

 # export CUDA_LAUNCH_BLOCKING=1 # for debugging

@@ -51,7 +51,7 @@ torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_
 --text_vocab_size 100 \
 --text_pad_token 100 \
 --phn_folder_name "phonemes" \
---manifest_name "manifest_large16khz_lessambi" \
+--manifest_name "manifest" \
 --encodec_folder_name ${encodec_codes_folder_name} \
 --audio_vocab_size 2048 \
 --empty_token 2048 \
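Before launching, it may help to verify the on-disk layout the script expects (paths are the README's placeholders; note the README puts `manifest/` under the save dir while `dataset_dir` points at the size subfolder, so check which one your dataloader resolves):

```bash
save_dir="path/to/store_extracted_codes_and_phonemes"
ls ${save_dir}/xl/encodec_16khz_4codebooks | head -3  # per-segment encodec code files
ls ${save_dir}/xl/phonemes | head -3                  # per-segment phoneme sequences
head -3 ${save_dir}/xl/vocab.txt                      # "<id> <phoneme>" lines
ls ${save_dir}/manifest                               # train.txt, validation.txt
```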