extraction,training,data,weights
parent
d754e9109a
commit
a129883910
59
README.md
|
@ -1,7 +1,7 @@
|
|||
# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
|
||||
[Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf)
|
||||
|
||||
TL;DR:
|
||||
### TL;DR
|
||||
VoiceCraft is a token-infilling neural codec language model that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.
|
||||
|
||||
To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference.
|
||||
|
@ -12,22 +12,25 @@ The TODOs left will be completed by the end of March 2024.
|
|||
- [x] Codebase upload
|
||||
- [x] Environment setup
|
||||
- [x] Inference demo for speech editing and TTS
|
||||
- [ ] Upload model weights
|
||||
- [ ] Training guidance
|
||||
- [ ] Upload the RealEdit dataset
|
||||
- [x] Training guidance
|
||||
- [x] Upload the RealEdit dataset and training manifest
|
||||
- [ ] Upload model weights (encodec weights are up)
|
||||
|
||||
|
||||
## Environment setup
|
||||
```bash
|
||||
conda create -n voicecraft python=3.9.16
|
||||
conda activate voicecraft
|
||||
|
||||
pip install torch==2.0.1 torchaudio==2.0.2 # this assumes your system is compatible with CUDA 11.7; otherwise, check out https://pytorch.org/get-started/previous-versions/#v201
|
||||
pip install torch==2.0.1 # this assumes your system is compatible with CUDA 11.7; otherwise, check out https://pytorch.org/get-started/previous-versions/#v201
|
||||
apt-get install ffmpeg # if you don't already have ffmpeg installed
|
||||
pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
|
||||
apt-get install espeak-ng # backend for the phonemizer installed below
|
||||
pip install tensorboard==2.16.2
|
||||
pip install phonemizer==3.2.1
|
||||
pip install tensorboard
|
||||
pip install datasets==2.12.0
|
||||
pip install torchaudio==2.0.2
|
||||
pip install datasets==2.16.0
|
||||
pip install torchmetrics==0.11.1
|
||||
# install MFA for getting forced-alignment, this could take a few minutes
|
||||
conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068
|
||||
# conda install pocl # the above gives a warning about installing pocl; not sure if this is really needed
|
||||
|
@ -36,9 +39,51 @@ conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=
|
|||
conda install -n voicecraft ipykernel --update-deps --force-reinstall
|
||||
```
|
||||
|
||||
If you encounter version issues when running things, check out [environment.yml](./environment.yml) for exact version matching.
|
||||
|
||||
## Inference Examples
|
||||
Check out [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb).
|
||||
|
||||
## Training
|
||||
To train a VoiceCraft model, you need to prepare the following parts:
|
||||
1. utterances and their transcripts
|
||||
2. codes obtained by encoding the utterances with e.g. Encodec
|
||||
3. phoneme sequences converted from the transcripts, and a phoneme set (we named it vocab.txt)
|
||||
4. manifest (i.e. metadata)
|
||||
|
||||
Steps 1, 2, and 3 are handled in [./data/phonemize_encodec_encode_hf.py](./data/phonemize_encodec_encode_hf.py), where
|
||||
1. Gigaspeech is downloaded through HuggingFace. Note that you need to sign an agreement in order to download the dataset (it needs your auth token)
|
||||
2. phoneme sequences and Encodec codes are also extracted by the script.
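
As a rough illustration of the transcript preprocessing that happens before phonemization: GigaSpeech transcripts mark punctuation with tags such as `<COMMA>`, and the script replaces them with the actual symbols. The sketch below mirrors that replacement only; `normalize_transcript` is a hypothetical helper name used for illustration, not a function in the repository.

```python
# Mapping from GigaSpeech punctuation tags to symbols. Note the leading space
# in each key, so that "HELLO <COMMA>" becomes "HELLO," rather than "HELLO ,".
PUNC2SYM = {
    " <COMMA>": ",",
    " <PERIOD>": ".",
    " <QUESTIONMARK>": "?",
    " <EXCLAMATIONPOINT>": "!",
}


def normalize_transcript(text: str) -> str:
    """Replace punctuation tags with their symbols (hypothetical helper)."""
    for tag, sym in PUNC2SYM.items():
        text = text.replace(tag, sym)
    return text


print(normalize_transcript("HELLO <COMMA> HOW ARE YOU <QUESTIONMARK>"))
# prints: HELLO, HOW ARE YOU?
```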
|
||||
|
||||
An example run:
|
||||
|
||||
```bash
|
||||
conda activate voicecraft
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
cd ./data
|
||||
python phonemize_encodec_encode_hf.py \
|
||||
--dataset_size xs \
|
||||
--download_to path/to/store_huggingface_downloads \
|
||||
--save_dir path/to/store_extracted_codes_and_phonemes \
|
||||
--encodec_model_path path/to/encodec_model \
|
||||
--mega_batch_size 120 \
|
||||
--batch_size 32 \
|
||||
--max_len 30000
|
||||
```
|
||||
where the model for `encodec_model_path` is available [here](https://huggingface.co/pyp1/VoiceCraft). This model is trained on Gigaspeech XL; it has 56M parameters and 4 codebooks, each with 2048 codes. Details are described in our [paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf). If you encounter OOM during extraction, try decreasing `batch_size` and/or `max_len`.
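
To get a feel for how much data each utterance produces, here is a small sketch of the arithmetic, assuming the script defaults (`--model_sr 16000`, `--downsample_rate 320`, i.e. 50 codes per second) and the 4 codebooks mentioned above. `expected_code_shape` is a hypothetical name used only for this example.

```python
MODEL_SR = 16000       # Encodec input sample rate (script default --model_sr)
DOWNSAMPLE_RATE = 320  # script default --downsample_rate
N_CODEBOOKS = 4        # number of codebooks of the released Encodec model


def expected_code_shape(duration_sec: float) -> tuple:
    """Return (n_codebooks, n_frames) expected for an utterance of this duration."""
    code_sr = MODEL_SR // DOWNSAMPLE_RATE  # 16000 / 320 = 50 codes per second
    return N_CODEBOOKS, round(duration_sec * code_sr)


print(expected_code_shape(3.5))
# prints: (4, 175)
```

Under these assumptions, a 3.5-second utterance yields 4 rows of 175 codes each, which is why long utterances dominate memory use during extraction.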
|
||||
The extracted codes, phonemes, and vocab.txt will be stored at `path/to/store_extracted_codes_and_phonemes/${dataset_size}/{encodec_16khz_4codebooks,phonemes,vocab.txt}`.
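
The extraction script writes each utterance's codes as a text file with one line per codebook and space-separated integers on each line. A minimal sketch for reading such a file back is shown below; `read_codes` is a hypothetical helper, and the demo file contents are made up purely for illustration.

```python
import tempfile
from pathlib import Path


def read_codes(path):
    """Read a codes file: one line per codebook, space-separated integer codes."""
    lines = Path(path).read_text().split("\n")
    return [[int(tok) for tok in line.split()] for line in lines if line.strip()]


# Self-contained demo with a fake 4-codebook, 3-frame file (made-up values).
with tempfile.TemporaryDirectory() as tmp:
    demo_fn = Path(tmp) / "demo_segment.txt"
    demo_fn.write_text("1 2 3\n4 5 6\n7 8 9\n10 11 12")
    codes = read_codes(demo_fn)

print(len(codes), len(codes[0]))
# prints: 4 3
```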
|
||||
|
||||
As for the manifest, please download train.txt and validation.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main), and put them under `path/to/store_extracted_codes_and_phonemes/manifest/`. Please also download vocab.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) if you want to use our pretrained VoiceCraft model (so that the phoneme-to-token matching is the same).
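
The reason the vocabulary must match is that the extraction script writes vocab.txt as one `index phoneme` pair per line, and the index order depends on the run. The following sketch parses that format into a lookup table; `parse_vocab` is a hypothetical helper and the sample lines are invented for the example.

```python
def parse_vocab(lines):
    """Parse 'index phoneme' lines into a phoneme -> index dictionary."""
    phn2idx = {}
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue
        # Split only on the first space: the index comes first, the phoneme second.
        idx, phn = line.split(" ", 1)
        phn2idx[phn] = int(idx)
    return phn2idx


# Invented sample lines in the same "index phoneme" layout.
sample = ["0 h\n", "1 æ\n", "2 <SIL>"]
vocab = parse_vocab(sample)
print(vocab["<SIL>"])
# prints: 2
```

If the indices in your own vocab.txt differ from those used to train a pretrained model, the same phoneme would map to a different token, so the downloaded file should be used with the pretrained weights.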
|
||||
|
||||
Now, you are good to start training!
|
||||
|
||||
```bash
|
||||
conda activate voicecraft
|
||||
cd ./z_scripts
|
||||
bash e830M.sh
|
||||
```
|
||||
|
||||
|
||||
## License
|
||||
The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some code from other repositories that are under different licenses: `./models/codebooks_patterns.py` is under MIT license; `./models/modules`, `./steps/optim.py`, `data/tokenizer.py` are under Apache License, Version 2.0; the phonemizer we used is under GNU 3.0 License. For a drop-in replacement of the phonemizer (i.e. text to IPA phoneme mapping), try [g2p](https://github.com/roedoejet/g2p) (MIT License) or [OpenPhonemizer](https://github.com/NeuralVox/OpenPhonemizer) (BSD-3-Clause Clear), although these are not tested.
|
||||
|
||||
|
|
|
@ -1,160 +0,0 @@
|
|||
import argparse
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
|
||||
parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!")
|
||||
parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files")
|
||||
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
|
||||
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
|
||||
parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes")
|
||||
parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
|
||||
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
|
||||
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
|
||||
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
|
||||
parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
|
||||
return parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
formatter = (
|
||||
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
|
||||
)
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import tqdm
|
||||
import time
|
||||
|
||||
args = parse_args()
|
||||
|
||||
manifest_dir = args.manifest_root # this dir is scp-ed
|
||||
audio_dir = args.audio_dir # this is scp-ed flac dir
|
||||
encodec_signature = args.encodec_model_path.split("/")[-2]
|
||||
save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}")
|
||||
os.makedirs(save_codes_dir, exist_ok=True)
|
||||
|
||||
|
||||
# model_sr = 16000
|
||||
# downsample_rate = 320
|
||||
# model_code_sr = 50
|
||||
def sort_by_audio_len(lens):
|
||||
inds = np.argsort(lens).tolist()
|
||||
logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
|
||||
logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
|
||||
logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
|
||||
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
|
||||
return inds[::-1]
|
||||
|
||||
def write_array_to_txt_file(array, filename):
|
||||
with open(filename, 'w') as f:
|
||||
for a in array[:-1]:
|
||||
f.write(' '.join(map(str, a))+'\n')
|
||||
f.write(' '.join(map(str, array[-1])))
|
||||
|
||||
|
||||
|
||||
class mydataset(torch.utils.data.Dataset):
|
||||
def __init__(self, split):
|
||||
super().__init__()
|
||||
# self.data = gs[split]
|
||||
self.split = split
|
||||
self.audio_root = audio_dir
|
||||
manifest_fn = os.path.join(manifest_dir, split+".txt")
|
||||
with open(manifest_fn, "r") as rf:
|
||||
self.data = [l.strip().split("\t") for l in rf.readlines()]
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
def __getitem__(self, ind):
|
||||
try:
|
||||
afn = self.data[ind][0]
|
||||
fn = os.path.join(self.audio_root, afn)
|
||||
audio, sr = torchaudio.load(fn)
|
||||
assert sr == args.model_sr, sr
|
||||
except Exception as e:
|
||||
logging.info(f"{e}")
|
||||
return None, None, None
|
||||
assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
|
||||
return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0]
|
||||
def collate(self, batch):
|
||||
lens, audios, segment_ids = [], [], []
|
||||
for item in batch:
|
||||
if item[0] != None:
|
||||
audios.append(item[0])
|
||||
lens.append(item[1])
|
||||
segment_ids.append(item[2])
|
||||
return audios, lens, segment_ids
|
||||
|
||||
# load the encodec model
|
||||
from audiocraft.solvers import CompressionSolver
|
||||
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
|
||||
model = model.cuda()
|
||||
model = model.eval()
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
|
||||
# setup dataloader
|
||||
mega_batch_size = 2100
|
||||
batch_size = args.batch_size
|
||||
train_dataset = mydataset('train')
|
||||
train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
|
||||
validation_dataset = mydataset('validation')
|
||||
validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
|
||||
test_dataset = mydataset('test')
|
||||
test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
|
||||
splits = ['validation', 'test', 'train']
|
||||
loaders = [validation_loader, test_loader, train_loader]
|
||||
# splits = ['validation'] # NOTE this is for debug, for example, see if the
|
||||
# loaders = [validation_loader]
|
||||
for split, loader in zip(splits, loaders):
|
||||
skip = 0
|
||||
logging.info(f"now processing split {split}...")
|
||||
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
|
||||
# mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
|
||||
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples")
|
||||
# with open(mani_fn, "a") as mani_wf: # resume from where we failed
|
||||
for m, mega_batch in enumerate(loader):
|
||||
logging.info(f"====================================")
|
||||
logging.info(f"====================================")
|
||||
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
|
||||
lengths = np.array(mega_batch[1])
|
||||
sorted_inds = sort_by_audio_len(lengths)
|
||||
for j in range(len(sorted_inds))[::-1]:
|
||||
if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s) or too long (longer than args.len_cap seconds)
|
||||
skip += 1
|
||||
del sorted_inds[j]
|
||||
|
||||
n_steps = int(np.ceil(len(sorted_inds) / batch_size))
|
||||
for n in tqdm.tqdm(range(n_steps), disable=True):
|
||||
inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
|
||||
wav_batch = [mega_batch[0][id] for id in inds_used]
|
||||
all_lens = [mega_batch[1][id] for id in inds_used]
|
||||
segment_id_batch = [mega_batch[2][id] for id in inds_used]
|
||||
# print(segment_id_batch)
|
||||
padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
|
||||
with torch.no_grad():
|
||||
if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes
|
||||
codes = []
|
||||
inwav = padded_wav.cuda()
|
||||
codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
|
||||
codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
|
||||
codes = torch.cat(codes, dim=0)
|
||||
else:
|
||||
encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model
|
||||
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
|
||||
codes = encoded_frames[0].cpu()
|
||||
|
||||
for i, length in enumerate(all_lens):
|
||||
save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
|
||||
actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
|
||||
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
|
||||
write_array_to_txt_file(cur_code, save_fn)
|
||||
|
||||
# mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
|
||||
# if i == 10:
|
||||
# raise
|
||||
# break
|
||||
# logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
|
||||
logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
|
||||
# break
|
|
@ -54,8 +54,6 @@ class dataset(torch.utils.data.Dataset):
|
|||
y = [[int(n)+self.args.n_special for n in l] for l in encos]
|
||||
else:
|
||||
y = [[int(n) for n in l] for l in encos]
|
||||
if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
|
||||
y = y[:1]
|
||||
except Exception as e:
|
||||
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
|
||||
logging.info(f"error message: {e}")
|
||||
|
@ -141,15 +139,15 @@ class dataset(torch.utils.data.Dataset):
|
|||
if self.args.pad_x:
|
||||
res["x"] = torch.stack(out["x"], dim=0)
|
||||
else:
|
||||
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
|
||||
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
|
||||
res["x_lens"] = torch.LongTensor(out["x_len"])
|
||||
if self.args.dynamic_batching:
|
||||
if out['y'][0].ndim==2:
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
|
||||
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
|
||||
else:
|
||||
assert out['y'][0].ndim==1, out['y'][0].shape
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
|
||||
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
|
||||
else:
|
||||
res['y'] = torch.stack(out['y'], dim=0)
|
||||
res["y_lens"] = torch.LongTensor(out["y_len"])
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
import argparse
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="phonemize the gigaspeech transcripts and encode the audio using the encodec model")
|
||||
parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl. we use xl for VoiceCraft training, xs is good for debugging')
|
||||
parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
|
||||
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
|
||||
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
|
||||
parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
|
||||
parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
|
||||
parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
|
||||
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
|
||||
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
|
||||
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
|
||||
parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
|
||||
parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine')
|
||||
return parser.parse_args()
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
formatter = (
|
||||
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
|
||||
)
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = parse_args()
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import time
|
||||
from datasets import load_dataset, DownloadConfig
|
||||
|
||||
from tokenizer import TextTokenizer, tokenize_text
|
||||
|
||||
# get the path
|
||||
phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
|
||||
codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
|
||||
vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
|
||||
os.makedirs(phn_save_root, exist_ok=True)
|
||||
os.makedirs(codes_save_root, exist_ok=True)
|
||||
|
||||
|
||||
def sort_by_audio_len(lens):
|
||||
inds = np.argsort(lens).tolist()
|
||||
logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
|
||||
logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
|
||||
logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
|
||||
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
|
||||
return inds[::-1]
|
||||
|
||||
def write_array_to_txt_file(array, filename):
|
||||
with open(filename, 'w') as f:
|
||||
for a in array[:-1]:
|
||||
f.write(' '.join(map(str, a))+'\n')
|
||||
f.write(' '.join(map(str, array[-1])))
|
||||
|
||||
|
||||
### phonemization
|
||||
# load tokenizer
|
||||
# load the encodec model
|
||||
from audiocraft.solvers import CompressionSolver
|
||||
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
|
||||
model = model.cuda()
|
||||
model = model.eval()
|
||||
text_tokenizer = TextTokenizer()
|
||||
|
||||
|
||||
# https://github.com/SpeechColab/GigaSpeech
|
||||
# there are only four different punctuations
|
||||
# need to check whether there are other < started strings
|
||||
punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
|
||||
gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are safely kept as the original symbols when using tokenize_text
|
||||
punc2sym.update(gar2sym)
|
||||
|
||||
word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
|
||||
forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
|
||||
|
||||
dc = DownloadConfig(cache_dir=args.download_to)
|
||||
stime = time.time()
|
||||
logging.info("loading the dataset...")
|
||||
gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc)
|
||||
logging.info(f"time spent on loading the dataset: {time.time() - stime:.2f} seconds")
|
||||
|
||||
splits = ['validation', 'test', 'train']
|
||||
|
||||
logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}")
|
||||
logging.info(f"phonemizing...")
|
||||
phn_vocab = set()
|
||||
all_lens = []
|
||||
|
||||
# you will see a ton of [WARNING] words_mismatch.py:88......, it's not an issue
|
||||
for split in tqdm.tqdm(splits):
|
||||
skip = 0
|
||||
logging.info(f"now processing split {split}...")
|
||||
for item in tqdm.tqdm(gs[split]):
|
||||
save_fn = os.path.join(phn_save_root, item['segment_id']+".txt")
|
||||
text = item['text']
|
||||
if sum(word in forbidden_words for word in text.split(" ")):
|
||||
logging.info(f"skip {item['segment_id']}, because it contains forbidden words. Its transcript: {text}")
|
||||
skip += 1
|
||||
continue
|
||||
for k, v in punc2sym.items():
|
||||
text = text.replace(k, v)
|
||||
phn = tokenize_text(text_tokenizer, text)
|
||||
phn_seq = " ".join(phn)
|
||||
for k, v in word2sym.items():
|
||||
phn_seq = phn_seq.replace(k, v)
|
||||
phn_vocab.update(phn_seq.split(" "))
|
||||
all_lens.append(len(phn_seq.split(" ")))
|
||||
with open(save_fn, "w") as f:
|
||||
f.write(phn_seq)
|
||||
logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbidden words")
|
||||
|
||||
print(f"phn vocab size: {len(list(phn_vocab))}")
|
||||
print("phn sequence stats: ")
|
||||
print(f"longest: {max(all_lens)}")
|
||||
print(f"shortest: {min(all_lens)}")
|
||||
print(f"median: {np.quantile(all_lens, 0.5)}")
|
||||
print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}")
|
||||
print("write vocabulary to ", vocab_fn)
|
||||
with open(vocab_fn, "w") as f:
|
||||
for i, phn in enumerate(list(phn_vocab)):
|
||||
if i < len(list(phn_vocab)) - 1:
|
||||
f.write(f"{str(i)} {phn}\n")
|
||||
else:
|
||||
f.write(f"{str(i)} {phn}")
|
||||
|
||||
class mydataset(torch.utils.data.Dataset):
|
||||
def __init__(self, split):
|
||||
super().__init__()
|
||||
self.data = gs[split]
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
def __getitem__(self, ind):
|
||||
try:
|
||||
segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time']
|
||||
except Exception:
|
||||
return None, None, None, None, None, None
|
||||
|
||||
return segment_id, audio, sr, text, begin_time, end_time
|
||||
def collate(self, batch):
|
||||
res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []}
|
||||
for item in batch:
|
||||
if item[0] is not None:
|
||||
res['segment_id'].append(item[0])
|
||||
res['audio'].append(item[1])
|
||||
res['sr'].append(item[2])
|
||||
res['text'].append(item[3])
|
||||
res['begin_time'].append(item[4])
|
||||
res['end_time'].append(item[5])
|
||||
return res
|
||||
|
||||
|
||||
## encodec codes extraction
|
||||
logging.info("encodec encoding...")
|
||||
train_dataset = mydataset('train')
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
|
||||
validation_dataset = mydataset('validation')
|
||||
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
|
||||
test_dataset = mydataset('test')
|
||||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
|
||||
splits = ['validation', 'test', 'train']
|
||||
loaders = [validation_loader, test_loader, train_loader]
|
||||
# splits = ['validation'] # for debug
|
||||
# loaders = [validation_loader]
|
||||
for split, loader in zip(splits, loaders):
|
||||
skip = 0
|
||||
logging.info(f"now processing split {split}...")
|
||||
mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size))
|
||||
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples")
|
||||
for m, mega_batch in enumerate(loader):
|
||||
logging.info(f"====================================")
|
||||
logging.info(f"====================================")
|
||||
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
|
||||
lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time'])
|
||||
sorted_inds = sort_by_audio_len(lengths)
|
||||
for j in range(len(sorted_inds))[::-1]:
|
||||
if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s) or too long (longer than args.len_cap seconds)
|
||||
skip += 1
|
||||
del sorted_inds[j]
|
||||
|
||||
n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
|
||||
for n in tqdm.tqdm(range(n_steps), disable=True):
|
||||
inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
|
||||
audio_batch = [mega_batch['audio'][id] for id in inds_used]
|
||||
sr_batch = [mega_batch['sr'][id] for id in inds_used]
|
||||
segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used]
|
||||
text_batch = [mega_batch['text'][id] for id in inds_used]
|
||||
padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
|
||||
all_lens = [lengths[id] for id in inds_used]
|
||||
with torch.no_grad():
|
||||
if max(all_lens) * args.model_sr > args.max_len and len(all_lens) > 1: # all_lens is in seconds while args.max_len is in samples; decrease args.max_len if OOM, or chunk the batch into more than 2 forward passes
|
||||
codes = []
|
||||
inwav = padded_wav.cuda()
|
||||
codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
|
||||
codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
|
||||
codes = torch.cat(codes, dim=0)
|
||||
else:
|
||||
encoded_frames = model.encode(padded_wav.cuda())
|
||||
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
|
||||
codes = encoded_frames[0].cpu()
|
||||
|
||||
for i, length in enumerate(all_lens):
|
||||
save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
|
||||
actual_len = round(length * args.model_code_sr) # length is in seconds; model_code_sr is the number of codes per second (50 for this model)
|
||||
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
|
||||
write_array_to_txt_file(cur_code, save_fn)
|
|
@ -0,0 +1,417 @@
|
|||
name: voicecraft
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=2_gnu
|
||||
- aom=3.8.2=h59595ed_0
|
||||
- asttokens=2.4.1=pyhd8ed1ab_0
|
||||
- atk-1.0=2.38.0=hd4edc92_1
|
||||
- audioread=3.0.1=py39hf3d152e_1
|
||||
- backcall=0.2.0=pyh9f0ad1d_0
|
||||
- baumwelch=0.3.7=h00ab1b0_5
|
||||
- biopython=1.79=py39hb9d737c_3
|
||||
- brotli=1.1.0=hd590300_1
|
||||
- brotli-bin=1.1.0=hd590300_1
|
||||
- brotli-python=1.1.0=py39h3d6467e_1
|
||||
- bzip2=1.0.8=hd590300_5
|
||||
- ca-certificates=2024.2.2=hbcca054_0
|
||||
- cairo=1.18.0=h3faef2a_0
|
||||
- certifi=2024.2.2=pyhd8ed1ab_0
|
||||
- cffi=1.16.0=py39h7a31438_0
|
||||
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
||||
- click=8.1.7=unix_pyh707e725_0
|
||||
- colorama=0.4.6=pyhd8ed1ab_0
|
||||
- comm=0.2.2=pyhd8ed1ab_0
|
||||
- contourpy=1.2.0=py39h7633fee_0
|
||||
- cycler=0.12.1=pyhd8ed1ab_0
|
||||
- dataclassy=1.0.1=pyhd8ed1ab_0
|
||||
- dav1d=1.2.1=hd590300_0
|
||||
- debugpy=1.8.1=py39h3d6467e_0
|
||||
- decorator=5.1.1=pyhd8ed1ab_0
|
||||
- executing=2.0.1=pyhd8ed1ab_0
|
||||
- expat=2.6.2=h59595ed_0
|
||||
- ffmpeg=6.1.1=gpl_h38e077a_106
|
||||
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
||||
- font-ttf-inconsolata=3.000=h77eed37_0
|
||||
- font-ttf-source-code-pro=2.038=h77eed37_0
|
||||
- font-ttf-ubuntu=0.83=h77eed37_1
|
||||
- fontconfig=2.14.2=h14ed4e7_0
|
||||
- fonts-conda-ecosystem=1=0
|
||||
- fonts-conda-forge=1=0
|
||||
- fonttools=4.49.0=py39hd1e30aa_0
|
||||
- freetype=2.12.1=h267a509_2
|
||||
- fribidi=1.0.10=h36c2ea0_0
|
||||
- gdk-pixbuf=2.42.10=h829c605_5
|
||||
- gettext=0.21.1=h27087fc_0
|
||||
- giflib=5.2.1=h0b41bf4_3
|
||||
- gmp=6.3.0=h59595ed_1
|
||||
- gnutls=3.7.9=hb077bed_0
|
||||
- graphite2=1.3.13=h58526e2_1001
|
||||
- graphviz=9.0.0=h78e8752_1
|
||||
- greenlet=3.0.3=py39h3d6467e_0
|
||||
- gtk2=2.24.33=h280cfa0_4
|
||||
- gts=0.7.6=h977cf35_4
|
||||
- harfbuzz=8.3.0=h3d44ed6_0
|
||||
- hdbscan=0.8.33=py39h44dd56e_4
|
||||
- icu=73.2=h59595ed_0
|
||||
- idna=3.6=pyhd8ed1ab_0
|
||||
- importlib-metadata=7.0.2=pyha770c72_0
|
||||
- importlib-resources=6.3.0=pyhd8ed1ab_0
|
||||
- importlib_metadata=7.0.2=hd8ed1ab_0
|
||||
- importlib_resources=6.3.0=pyhd8ed1ab_0
|
||||
- ipykernel=6.29.3=pyhd33586a_0
|
||||
- jedi=0.19.1=pyhd8ed1ab_0
|
||||
- joblib=1.3.2=pyhd8ed1ab_0
|
||||
- jupyter_client=8.6.1=pyhd8ed1ab_0
|
||||
- jupyter_core=5.7.2=py39hf3d152e_0
|
||||
- kaldi=5.5.1068=cpu_h31769b2_2
|
||||
- keyutils=1.6.1=h166bdaf_0
|
||||
- kiwisolver=1.4.5=py39h7633fee_1
|
||||
- kneed=0.8.5=pyhd8ed1ab_0
|
||||
- krb5=1.21.2=h659d440_0
|
||||
- lame=3.100=h166bdaf_1003
|
||||
- lazy_loader=0.3=pyhd8ed1ab_0
|
||||
- lcms2=2.16=hb7c19ff_0
|
||||
- ld_impl_linux-64=2.40=h41732ed_0
|
||||
- lerc=4.0.0=h27087fc_0
|
||||
- libabseil=20240116.1=cxx17_h59595ed_2
|
||||
- libass=0.17.1=h8fe9dca_1
|
||||
- libblas=3.9.0=21_linux64_openblas
|
||||
- libbrotlicommon=1.1.0=hd590300_1
|
||||
- libbrotlidec=1.1.0=hd590300_1
|
||||
- libbrotlienc=1.1.0=hd590300_1
|
||||
- libcblas=3.9.0=21_linux64_openblas
|
||||
- libclang-cpp15=15.0.7=default_hb11cfb5_4
|
||||
- libdeflate=1.19=hd590300_0
|
||||
- libdrm=2.4.120=hd590300_0
|
||||
- libedit=3.1.20191231=he28a2e2_2
|
||||
- libexpat=2.6.2=h59595ed_0
|
||||
- libffi=3.4.2=h7f98852_5
|
||||
- libflac=1.4.3=h59595ed_0
|
||||
- libgcc-ng=13.2.0=h807b86a_5
|
||||
- libgd=2.3.3=h119a65a_9
|
||||
- libgfortran-ng=13.2.0=h69a702a_5
|
||||
- libgfortran5=13.2.0=ha4646dd_5
|
||||
- libglib=2.80.0=hf2295e7_0
|
||||
- libgomp=13.2.0=h807b86a_5
|
||||
- libhwloc=2.9.3=default_h554bfaf_1009
|
||||
- libiconv=1.17=hd590300_2
|
||||
- libidn2=2.3.7=hd590300_0
|
||||
- libjpeg-turbo=3.0.0=hd590300_1
|
||||
- liblapack=3.9.0=21_linux64_openblas
|
||||
- liblapacke=3.9.0=21_linux64_openblas
|
||||
- libllvm14=14.0.6=hcd5def8_4
|
||||
- libllvm15=15.0.7=hb3ce162_4
|
||||
- libllvmspirv15=15.0.0=h0cdce71_1
|
||||
- libnsl=2.0.1=hd590300_0
|
||||
- libogg=1.3.4=h7f98852_1
|
||||
- libopenblas=0.3.26=pthreads_h413a1c8_0
|
||||
- libopenvino=2024.0.0=h2e90f83_1
|
||||
- libopenvino-auto-batch-plugin=2024.0.0=hd5fc58b_1
|
||||
- libopenvino-auto-plugin=2024.0.0=hd5fc58b_1
|
||||
- libopenvino-hetero-plugin=2024.0.0=h3ecfda7_1
|
||||
- libopenvino-intel-cpu-plugin=2024.0.0=h2e90f83_1
|
||||
- libopenvino-intel-gpu-plugin=2024.0.0=h2e90f83_1
|
||||
- libopenvino-ir-frontend=2024.0.0=h3ecfda7_1
|
||||
- libopenvino-onnx-frontend=2024.0.0=h757c851_1
|
||||
- libopenvino-paddle-frontend=2024.0.0=h757c851_1
|
||||
- libopenvino-pytorch-frontend=2024.0.0=h59595ed_1
|
||||
- libopenvino-tensorflow-frontend=2024.0.0=hca94c1a_1
|
||||
- libopenvino-tensorflow-lite-frontend=2024.0.0=h59595ed_1
|
||||
- libopus=1.3.1=h7f98852_1
|
||||
- libpciaccess=0.18=hd590300_0
|
||||
- libpng=1.6.43=h2797004_0
|
||||
- libpq=16.2=h33b98f1_0
|
||||
- libprotobuf=4.25.3=h08a7969_0
|
||||
- librosa=0.10.1=pyhd8ed1ab_0
|
||||
- librsvg=2.56.3=he3f83f7_1
|
||||
- libsndfile=1.2.2=hc60ed4a_1
|
||||
- libsodium=1.0.18=h36c2ea0_1
|
||||
- libsqlite=3.45.2=h2797004_0
|
||||
- libstdcxx-ng=13.2.0=h7e041cc_5
|
||||
- libtasn1=4.19.0=h166bdaf_0
|
||||
- libtiff=4.6.0=ha9c0a0a_2
|
||||
- libunistring=0.9.10=h7f98852_0
|
||||
- libuuid=2.38.1=h0b41bf4_0
|
||||
- libva=2.21.0=hd590300_0
|
||||
- libvorbis=1.3.7=h9c3ff4c_0
|
||||
- libvpx=1.14.0=h59595ed_0
|
||||
- libwebp=1.3.2=h658648e_1
|
||||
- libwebp-base=1.3.2=hd590300_0
|
||||
- libxcb=1.15=h0b41bf4_0
|
||||
- libxcrypt=4.4.36=hd590300_1
|
||||
- libxml2=2.12.5=h232c23b_0
|
||||
- libzlib=1.2.13=hd590300_5
|
||||
- llvm-spirv-15=15.0.0=h0cdce71_1
|
||||
- mad=0.15.1b=h9c3ff4c_1
|
||||
- markdown-it-py=3.0.0=pyhd8ed1ab_0
|
||||
- matplotlib-base=3.8.3=py39he9076e7_0
|
||||
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
|
||||
- mdurl=0.1.2=pyhd8ed1ab_0
|
||||
- montreal-forced-aligner=2.2.17=pyhd8ed1ab_0
|
||||
- mpg123=1.32.4=h59595ed_0
|
||||
- msgpack-python=1.0.7=py39h7633fee_0
|
||||
- munkres=1.1.4=pyh9f0ad1d_0
|
||||
- ncurses=6.4=h59595ed_2
|
||||
- nest-asyncio=1.6.0=pyhd8ed1ab_0
|
||||
- nettle=3.9.1=h7ab15ed_0
|
||||
- ngram=1.3.14=h924138e_2
|
||||
- numba=0.59.0=py39h615d6bd_1
|
||||
- numpy=1.26.4=py39h474f0d3_0
|
||||
- ocl-icd=2.3.2=hd590300_0
|
||||
- openfst=1.8.2=h924138e_2
|
||||
- openh264=2.4.1=h59595ed_0
|
||||
- openjpeg=2.5.2=h488ebb8_0
|
||||
- openssl=3.2.1=hd590300_0
|
||||
- p11-kit=0.24.1=hc5aa10d_0
|
||||
- packaging=24.0=pyhd8ed1ab_0
|
||||
- pandas=2.2.1=py39hddac248_0
|
||||
- pango=1.52.1=ha41ecd1_0
|
||||
- parso=0.8.3=pyhd8ed1ab_0
|
||||
- patsy=0.5.6=pyhd8ed1ab_0
|
||||
- pcre2=10.43=hcad00b1_0
|
||||
- pexpect=4.9.0=pyhd8ed1ab_0
|
||||
- pgvector-python=0.2.5=pyhe093146_0
|
||||
- pickleshare=0.7.5=py_1003
|
||||
- pillow=10.2.0=py39had0adad_0
|
||||
- pip=24.0=pyhd8ed1ab_0
|
||||
- pixman=0.43.2=h59595ed_0
|
||||
- platformdirs=4.2.0=pyhd8ed1ab_0
|
||||
- pocl=5.0=h03a6ac1_2
|
||||
- pocl-core=5.0=hdaecddf_2
|
||||
- pocl-cpu=5.0=he901f76_2
|
||||
- pocl-cpu-minimal=5.0=h5ccd973_2
|
||||
- pocl-cuda=5.0=hdaecddf_2
|
||||
- pocl-remote=5.0=h5ccd973_2
|
||||
- pooch=1.8.1=pyhd8ed1ab_0
|
||||
- postgresql=16.2=h7387d8b_0
|
||||
- prompt-toolkit=3.0.42=pyha770c72_0
|
||||
- prompt_toolkit=3.0.42=hd8ed1ab_0
|
||||
- psutil=5.9.8=py39hd1e30aa_0
|
||||
- psycopg2=2.9.9=py39h89197e3_0
|
||||
- pthread-stubs=0.4=h36c2ea0_1001
|
||||
- ptyprocess=0.7.0=pyhd3deb0d_0
|
||||
- pugixml=1.14=h59595ed_0
|
||||
- pure_eval=0.2.2=pyhd8ed1ab_0
|
||||
- pycparser=2.21=pyhd8ed1ab_0
|
||||
- pygments=2.17.2=pyhd8ed1ab_0
|
||||
- pyparsing=3.1.2=pyhd8ed1ab_0
|
||||
- pysocks=1.7.1=pyha2e5f31_6
|
||||
- pysoundfile=0.12.1=pypyhd8ed1ab_1
|
||||
- python=3.9.18=h0755675_1_cpython
|
||||
- python-tzdata=2024.1=pyhd8ed1ab_0
|
||||
- python_abi=3.9=4_cp39
|
||||
- pytz=2024.1=pyhd8ed1ab_0
|
||||
- pyyaml=6.0.1=py39hd1e30aa_1
|
||||
- pyzmq=25.1.2=py39h8c080ef_0
|
||||
- readline=8.2=h8228510_1
|
||||
- requests=2.31.0=pyhd8ed1ab_0
|
||||
- rich=13.7.1=pyhd8ed1ab_0
|
||||
- rich-click=1.7.4=pyhd8ed1ab_0
|
||||
- scikit-learn=1.2.2=py39hc236052_2
|
||||
- scipy=1.12.0=py39h474f0d3_2
|
||||
- seaborn=0.13.2=hd8ed1ab_0
|
||||
- seaborn-base=0.13.2=pyhd8ed1ab_0
|
||||
- setuptools=69.2.0=pyhd8ed1ab_0
|
||||
- six=1.16.0=pyh6c4a22f_0
|
||||
- snappy=1.1.10=h9fff704_0
|
||||
- sox=14.4.2=ha5cc309_1018
|
||||
- soxr=0.1.3=h0b41bf4_3
|
||||
- soxr-python=0.3.7=py39h44dd56e_0
|
||||
- sqlalchemy=2.0.28=py39hd1e30aa_0
|
||||
- sqlite=3.45.2=h2c6b66d_0
|
||||
- stack_data=0.6.2=pyhd8ed1ab_0
|
||||
- statsmodels=0.14.1=py39h44dd56e_0
|
||||
- svt-av1=1.8.0=h59595ed_0
|
||||
- tbb=2021.11.0=h00ab1b0_1
|
||||
- threadpoolctl=3.3.0=pyhc1e730c_0
|
||||
- tk=8.6.13=noxft_h4845f30_101
|
||||
- tornado=6.4=py39hd1e30aa_0
|
||||
- tqdm=4.66.2=pyhd8ed1ab_0
|
||||
- traitlets=5.14.2=pyhd8ed1ab_0
|
||||
- typing-extensions=4.10.0=hd8ed1ab_0
|
||||
- typing_extensions=4.10.0=pyha770c72_0
|
||||
- tzcode=2024a=h3f72095_0
|
||||
- tzdata=2024a=h0c530f3_0
|
||||
- unicodedata2=15.1.0=py39hd1e30aa_0
|
||||
- urllib3=2.2.1=pyhd8ed1ab_0
|
||||
- wcwidth=0.2.13=pyhd8ed1ab_0
|
||||
- wheel=0.42.0=pyhd8ed1ab_0
|
||||
- x264=1!164.3095=h166bdaf_2
|
||||
- x265=3.5=h924138e_3
|
||||
- xorg-fixesproto=5.0=h7f98852_1002
|
||||
- xorg-kbproto=1.0.7=h7f98852_1002
|
||||
- xorg-libice=1.1.1=hd590300_0
|
||||
- xorg-libsm=1.2.4=h7391055_0
|
||||
- xorg-libx11=1.8.7=h8ee46fc_0
|
||||
- xorg-libxau=1.0.11=hd590300_0
|
||||
- xorg-libxdmcp=1.1.3=h7f98852_0
|
||||
- xorg-libxext=1.3.4=h0b41bf4_2
|
||||
- xorg-libxfixes=5.0.3=h7f98852_1004
|
||||
- xorg-libxrender=0.9.11=hd590300_0
|
||||
- xorg-renderproto=0.11.1=h7f98852_1002
|
||||
- xorg-xextproto=7.3.0=h0b41bf4_1003
|
||||
- xorg-xproto=7.0.31=h7f98852_1007
|
||||
- xz=5.2.6=h166bdaf_0
|
||||
- yaml=0.2.5=h7f98852_2
|
||||
- zeromq=4.3.5=h59595ed_1
|
||||
- zipp=3.17.0=pyhd8ed1ab_0
|
||||
- zlib=1.2.13=hd590300_5
|
||||
- zstd=1.5.5=hfc55251_0
|
||||
- pip:
|
||||
- absl-py==2.1.0
|
||||
- aiofiles==23.2.1
|
||||
- aiohttp==3.9.3
|
||||
- aiosignal==1.3.1
|
||||
- altair==5.2.0
|
||||
- antlr4-python3-runtime==4.9.3
|
||||
- anyio==4.3.0
|
||||
- async-timeout==4.0.3
|
||||
- attrs==23.2.0
|
||||
- av==11.0.0
|
||||
- babel==2.14.0
|
||||
- beautifulsoup4==4.12.3
|
||||
- bibtexparser==2.0.0b7
|
||||
- bleach==6.1.0
|
||||
- blis==0.7.11
|
||||
- catalogue==2.0.10
|
||||
- clldutils==3.22.2
|
||||
- cloudpickle==3.0.0
|
||||
- cmake==3.28.3
|
||||
- colorlog==6.8.2
|
||||
- confection==0.1.4
|
||||
- csvw==3.3.0
|
||||
- cymem==2.0.8
|
||||
- cython==0.29.37
|
||||
- datasets==2.16.0
|
||||
- defusedxml==0.7.1
|
||||
- demucs==4.0.1
|
||||
- dill==0.3.6
|
||||
- dlinfo==1.2.1
|
||||
- docopt==0.6.2
|
||||
- dora-search==0.1.12
|
||||
- einops==0.7.0
|
||||
- encodec==0.1.1
|
||||
- exceptiongroup==1.2.0
|
||||
- fastapi==0.110.0
|
||||
- fastjsonschema==2.19.1
|
||||
- ffmpy==0.3.2
|
||||
- filelock==3.13.1
|
||||
- flashy==0.0.2
|
||||
- frozenlist==1.4.1
|
||||
- fsspec==2023.10.0
|
||||
- gradio==3.50.2
|
||||
- gradio-client==0.6.1
|
||||
- grpcio==1.62.1
|
||||
- h11==0.14.0
|
||||
- httpcore==1.0.4
|
||||
- httpx==0.27.0
|
||||
- huggingface-hub==0.21.4
|
||||
- hydra-colorlog==1.2.0
|
||||
- hydra-core==1.3.2
|
||||
- ipython==8.12.3
|
||||
- isodate==0.6.1
|
||||
- jinja2==3.1.3
|
||||
- jsonschema==4.21.1
|
||||
- jsonschema-specifications==2023.12.1
|
||||
- julius==0.2.7
|
||||
- jupyterlab-pygments==0.3.0
|
||||
- lameenc==1.7.0
|
||||
- langcodes==3.3.0
|
||||
- language-tags==1.2.0
|
||||
- lit==18.1.1
|
||||
- llvmlite==0.42.0
|
||||
- lxml==5.1.0
|
||||
- markdown==3.5.2
|
||||
- markupsafe==2.1.5
|
||||
- mistune==3.0.2
|
||||
- mpmath==1.3.0
|
||||
- msgpack==1.0.8
|
||||
- multidict==6.0.5
|
||||
- multiprocess==0.70.14
|
||||
- murmurhash==1.0.10
|
||||
- nbclient==0.10.0
|
||||
- nbconvert==7.16.3
|
||||
- nbformat==5.10.3
|
||||
- networkx==3.2.1
|
||||
- num2words==0.5.13
|
||||
- nvidia-cublas-cu11==11.10.3.66
|
||||
- nvidia-cuda-cupti-cu11==11.7.101
|
||||
- nvidia-cuda-nvrtc-cu11==11.7.99
|
||||
- nvidia-cuda-runtime-cu11==11.7.99
|
||||
- nvidia-cudnn-cu11==8.5.0.96
|
||||
- nvidia-cufft-cu11==10.9.0.58
|
||||
- nvidia-curand-cu11==10.2.10.91
|
||||
- nvidia-cusolver-cu11==11.4.0.1
|
||||
- nvidia-cusparse-cu11==11.7.4.91
|
||||
- nvidia-nccl-cu11==2.14.3
|
||||
- nvidia-nvtx-cu11==11.7.91
|
||||
- omegaconf==2.3.0
|
||||
- openunmix==1.2.1
|
||||
- orjson==3.9.15
|
||||
- pandocfilters==1.5.1
|
||||
- pathlib-abc==0.1.1
|
||||
- pathy==0.11.0
|
||||
- pgvector==0.2.2
|
||||
- phonemizer==3.2.1
|
||||
- pipreqs==0.5.0
|
||||
- praatio==6.2.0
|
||||
- preshed==3.0.9
|
||||
- protobuf==4.25.3
|
||||
- pyarrow==15.0.2
|
||||
- pyarrow-hotfix==0.6
|
||||
- pydantic==1.10.14
|
||||
- pydub==0.25.1
|
||||
- pylatexenc==2.10
|
||||
- pynini==2.1.6
|
||||
- pypinyin==0.48.0
|
||||
- python-dateutil==2.9.0.post0
|
||||
- python-multipart==0.0.9
|
||||
- rdflib==7.0.0
|
||||
- referencing==0.33.0
|
||||
- regex==2023.12.25
|
||||
- responses==0.18.0
|
||||
- retrying==1.3.4
|
||||
- rfc3986==1.5.0
|
||||
- rpds-py==0.18.0
|
||||
- safetensors==0.4.2
|
||||
- segments==2.2.1
|
||||
- semantic-version==2.10.0
|
||||
- sentencepiece==0.2.0
|
||||
- smart-open==6.4.0
|
||||
- sniffio==1.3.1
|
||||
- soupsieve==2.5
|
||||
- spacy==3.5.2
|
||||
- spacy-legacy==3.0.12
|
||||
- spacy-loggers==1.0.5
|
||||
- srsly==2.4.8
|
||||
- starlette==0.36.3
|
||||
- submitit==1.5.1
|
||||
- sympy==1.12
|
||||
- tabulate==0.9.0
|
||||
- tensorboard==2.16.2
|
||||
- tensorboard-data-server==0.7.2
|
||||
- thinc==8.1.12
|
||||
- tinycss2==1.2.1
|
||||
- tokenizers==0.15.2
|
||||
- toolz==0.12.1
|
||||
- torch==2.0.1
|
||||
- torchaudio==2.0.2
|
||||
- torchmetrics==0.11.1
|
||||
- transformers==4.38.2
|
||||
- treetable==0.2.5
|
||||
- triton==2.0.0
|
||||
- typer==0.7.0
|
||||
- uritemplate==4.1.1
|
||||
- uvicorn==0.28.0
|
||||
- wasabi==1.1.2
|
||||
- webencodings==0.5.1
|
||||
- websockets==11.0.3
|
||||
- werkzeug==3.0.1
|
||||
- xformers==0.0.22
|
||||
- xxhash==3.4.1
|
||||
- yarg==0.1.9
|
||||
- yarl==1.9.4
|
||||
prefix: /home/pyp/miniconda3/envs/voicecraft
|
|
@@ -504,7 +504,7 @@ class VoiceCraft(nn.Module):
 ntokens = []
 top10acc = []
 for k, (logit, target) in enumerate(zip(logits, targets)):
-    loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None))
+    loss.append(F.cross_entropy(logit, target, reduction='mean'))
     top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
     ntokens.append(len(logit))

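The two `F.cross_entropy` calls in this hunk differ only in the `weight` argument. As a rough illustration of what that argument changes, the sketch below computes mean cross-entropy with optional per-class weights in plain Python. The function name `mean_cross_entropy` and the toy inputs are made up for this note. The normalisation (dividing by the summed weights of the targets) follows the documented PyTorch behaviour for `reduction='mean'`; nothing here is taken from the model code itself.

```python
import math

def mean_cross_entropy(logits, targets, class_weight=None):
    """Mean cross-entropy over rows of logits, optionally weighted per class.

    logits: list of rows, each a list of floats (one score per class)
    targets: list of class indices, one per row
    class_weight: optional list with one weight per class
    """
    total, norm = 0.0, 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        w = 1.0 if class_weight is None else class_weight[t]
        total += w * (log_z - row[t])  # weighted negative log-likelihood
        norm += w                      # weighted mean divides by summed weights
    return total / norm

# Hypothetical toy batch: two frames, two classes.
toy_logits = [[2.0, 0.0], [0.0, 0.0]]
toy_targets = [0, 1]
unweighted = mean_cross_entropy(toy_logits, toy_targets)
weighted = mean_cross_entropy(toy_logits, toy_targets, class_weight=[1.0, 3.0])
print(unweighted, weighted)
```

With a weight above 1 on a class, frames whose target is that class contribute more to the averaged loss; with no weights, every frame counts equally.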
@@ -988,6 +988,8 @@ class VoiceCraft(nn.Module):
 for jj in range(1,self.args.n_codebooks):
     logits_adjust[jj][eog_inference] = -10000
     logits_adjust[jj][self.args.empty_token] = -10000
+if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
+    logits_adjust[0][eog_inference] = -10000
 ##################### silence repetition handling #####################
 if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
     if logits_adjust[0, prev_token] < 0:
@@ -1237,6 +1239,8 @@ class VoiceCraft(nn.Module):
 for jj in range(1,self.args.n_codebooks):
     logits_adjust[:,jj,eog_inference] = -10000
     logits_adjust[:,jj,self.args.empty_token] = -10000
+if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
+    logits_adjust[:,:,eog_inference] = -10000
 ##################### silence repetition handling #####################
 for b in range(batch_size):
     prev_token = prev_tokens[b]
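Both inference hunks in this file guard against the model emitting the end-of-generation token almost immediately: while `cur_num_gen` is at most `encodec_sr // 5`, the end-of-generation logit is pushed to `-10000` so it cannot be sampled. A minimal sketch of that idea for the non-batched case, using plain lists instead of tensors, might look like the following. The function name `block_early_stop`, the toy shapes, and the value `encodec_sr=50` are assumptions chosen for illustration, not values confirmed by this diff.

```python
NEG = -10000  # large negative logit, the same constant the hunks use

def block_early_stop(logits, eog_token, empty_token, cur_num_gen, encodec_sr):
    """Return a copy of per-codebook logits with early stopping disabled.

    logits: list of rows, one row of vocabulary scores per codebook
    cur_num_gen: number of tokens generated so far
    encodec_sr: codec tokens per second (hypothetical value in the demo below)
    """
    adjusted = [row[:] for row in logits]  # leave the caller's logits untouched
    # Codebooks after the first never emit the end or empty token.
    for jj in range(1, len(adjusted)):
        adjusted[jj][eog_token] = NEG
        adjusted[jj][empty_token] = NEG
    # Until roughly a fifth of a second has been generated, also block
    # the end token in the first codebook.
    if cur_num_gen <= encodec_sr // 5:
        adjusted[0][eog_token] = NEG
    return adjusted

# Hypothetical toy setup: 4 codebooks, vocabulary of 6 entries.
toy = [[0.0] * 6 for _ in range(4)]
early = block_early_stop(toy, eog_token=4, empty_token=5, cur_num_gen=3, encodec_sr=50)
late = block_early_stop(toy, eog_token=4, empty_token=5, cur_num_gen=11, encodec_sr=50)
print(early[0][4], late[0][4])
```

The batched hunk applies the same threshold but masks the end token across every batch item and codebook at once (`logits_adjust[:,:,eog_inference]`).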
@@ -7,9 +7,9 @@ export WORLD_SIZE=4
 dataset=gigaspeech
 mkdir -p ./logs/${dataset}

-exp_root="/data/scratch/pyp/exp_pyp/VoiceCraft"
+exp_root="path/to/store/exp_results"
 exp_name=e830M
-dataset_dir="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl"
+dataset_dir="path/to/stored_extracted_codes_and_phonemes/xl" # xs if you only extracted xs in previous step
 encodec_codes_folder_name="encodec_16khz_4codebooks"

 # export CUDA_LAUNCH_BLOCKING=1 # for debugging
@@ -51,7 +51,7 @@ torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_
 --text_vocab_size 100 \
 --text_pad_token 100 \
 --phn_folder_name "phonemes" \
---manifest_name "manifest_large16khz_lessambi" \
+--manifest_name "manifest" \
 --encodec_folder_name ${encodec_codes_folder_name} \
 --audio_vocab_size 2048 \
 --empty_token 2048 \
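Since the script now uses placeholder paths, it can help to confirm that `dataset_dir` contains the folders named by `--phn_folder_name`, `--manifest_name`, and `--encodec_folder_name` before launching `torchrun`. The helper below is a hypothetical pre-flight check, not part of the repository. It assumes all three names are sub-folders directly under `dataset_dir`, which this diff does not confirm.

```python
from pathlib import Path

def missing_inputs(dataset_dir,
                   phn_folder_name="phonemes",
                   manifest_name="manifest",
                   encodec_folder_name="encodec_16khz_4codebooks"):
    """Return the expected sub-folder names that are absent under dataset_dir.

    The default names are copied from the training script arguments above;
    the assumption that each one is a directory is ours.
    """
    root = Path(dataset_dir)
    expected = [phn_folder_name, manifest_name, encodec_folder_name]
    return [name for name in expected if not (root / name).is_dir()]
```

A call such as `missing_inputs("path/to/stored_extracted_codes_and_phonemes/xl")` returns an empty list when every folder is present, and otherwise the names that still need to be extracted or downloaded.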