extraction,training,data,weights

This commit is contained in:
jason-on-salt-a40 2024-03-24 19:43:37 -07:00
parent d754e9109a
commit a129883910
7 changed files with 686 additions and 176 deletions


@ -1,7 +1,7 @@
# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
[Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf)
TL;DR:
### TL;DR
VoiceCraft is a token-infilling neural codec language model that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data, including audiobooks, internet videos, and podcasts.
To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference audio.
@ -12,22 +12,25 @@ The TODOs left will be completed by the end of March 2024.
- [x] Codebase upload
- [x] Environment setup
- [x] Inference demo for speech editing and TTS
- [ ] Upload model weights
- [ ] Training guidance
- [ ] Upload the RealEdit dataset
- [x] Training guidance
- [x] Upload the RealEdit dataset and training manifest
- [ ] Upload model weights (encodec weights are up)
## Environment setup
```bash
conda create -n voicecraft python=3.9.16
conda activate voicecraft
pip install torch==2.0.1 torchaudio==2.0.2 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201
pip install torch==2.0.1 # this assumes your system is compatible with CUDA 11.7; otherwise check out https://pytorch.org/get-started/previous-versions/#v201
apt-get install ffmpeg # if you don't already have ffmpeg installed
pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
apt-get install espeak-ng # backend for the phonemizer installed below
pip install tensorboard==2.16.2
pip install phonemizer==3.2.1
pip install tensorboard
pip install datasets==2.12.0
pip install torchaudio==2.0.2
pip install datasets==2.16.0
pip install torchmetrics==0.11.1
# install MFA for forced alignment; this could take a few minutes
conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068
# conda install pocl # the above gives a warning about installing pocl; not sure if it is really needed
@ -36,9 +39,51 @@ conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=
conda install -n voicecraft ipykernel --update-deps --force-reinstall
```
If you run into version issues when running things, check out [environment.yml](./environment.yml) for exact version matching.
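Alternatively, a minimal sketch of recreating the pinned environment directly from that file (assuming conda is available and you run this from the repo root):
```bash
conda env create -f environment.yml  # creates the `voicecraft` env with the exact pinned versions
conda activate voicecraft
```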
## Inference Examples
Check out [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb)
## Training
To train a VoiceCraft model, you need to prepare the following parts:
1. utterances and their transcripts
2. the utterances encoded into codes using, e.g., Encodec
3. the transcripts converted into phoneme sequences, plus a phoneme set (which we name vocab.txt)
4. a manifest (i.e. metadata)
Steps 1, 2, and 3 are handled in [./data/phonemize_encodec_encode_hf.py](./data/phonemize_encodec_encode_hf.py), where
1. Gigaspeech is downloaded through HuggingFace (note that you need to sign an agreement to download the dataset, and the script needs your auth token), and
2. phoneme sequences and encodec codes are extracted by the script.
An example run:
```bash
conda activate voicecraft
export CUDA_VISIBLE_DEVICES=0
cd ./data
python phonemize_encodec_encode_hf.py \
--dataset_size xs \
--download_to path/to/store_huggingface_downloads \
--save_dir path/to/store_extracted_codes_and_phonemes \
--encodec_model_path path/to/encodec_model \
--mega_batch_size 120 \
--batch_size 32 \
--max_len 30000
```
where `encodec_model_path` is available [here](https://huggingface.co/pyp1/VoiceCraft). This model was trained on Gigaspeech XL; it has 56M parameters and 4 codebooks, each with 2048 codes. Details are described in our [paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf). If you encounter OOM during extraction, try decreasing `batch_size` and/or `max_len`.
The extracted codes, phonemes, and vocab.txt will be stored at `path/to/store_extracted_codes_and_phonemes/${dataset_size}/{encodec_16khz_4codebooks,phonemes,vocab.txt}`.
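A quick way to sanity-check the extraction output (a sketch; `xs` stands in for whatever `dataset_size` you used):
```bash
ls path/to/store_extracted_codes_and_phonemes/xs
# expected: encodec_16khz_4codebooks/  phonemes/  vocab.txt
# both folders contain one <segment_id>.txt file per utterance, written by the extraction script
```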
For the manifest, please download `train.txt` and `validation.txt` from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) and put them under `path/to/store_extracted_codes_and_phonemes/manifest/`. Please also download `vocab.txt` from the same place if you want to use our pretrained VoiceCraft model (so that the phoneme-to-token matching is the same).
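A sketch of fetching the manifest files, assuming the standard HuggingFace `resolve/main` download URLs for that dataset repo (downloading them through the web UI works just as well):
```bash
manifest_dir=path/to/store_extracted_codes_and_phonemes/manifest
mkdir -p $manifest_dir
wget -P $manifest_dir https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/resolve/main/train.txt
wget -P $manifest_dir https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/resolve/main/validation.txt
# optional: the vocab.txt that matches our pretrained model
wget https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/resolve/main/vocab.txt
```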
Now, you are good to start training!
```bash
conda activate voicecraft
cd ./z_scripts
bash e830M.sh
```
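Before launching, edit the paths near the top of the training script so they point at your experiment folder and extracted data; the relevant variables, as they appear in the diff of `e830M.sh` at the bottom of this commit, look like:
```bash
exp_root="path/to/store/exp_results"
dataset_dir="path/to/store_extracted_codes_and_phonemes/xl"  # use .../xs if you only extracted xs above
encodec_codes_folder_name="encodec_16khz_4codebooks"
```
The script also passes `--phn_folder_name "phonemes"` and `--manifest_name "manifest"`, which should match the folder names used above.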
## License
The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under the Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some code from other repositories that is under different licenses: `./models/codebooks_patterns.py` is under the MIT license; `./models/modules`, `./steps/optim.py`, and `data/tokenizer.py` are under the Apache License, Version 2.0; the phonemizer we use is under the GNU GPL 3.0 license. For a drop-in replacement of the phonemizer (i.e. text-to-IPA phoneme mapping), try [g2p](https://github.com/roedoejet/g2p) (MIT License) or [OpenPhonemizer](https://github.com/NeuralVox/OpenPhonemizer) (BSD-3-Clause Clear), although these are not tested.


@ -1,160 +0,0 @@
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!")
parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files")
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes")
parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
return parser.parse_args()
if __name__ == "__main__":
import logging
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
import os
import numpy as np
import torch
import torchaudio
import tqdm
import time
args = parse_args()
manifest_dir = args.manifest_root # this dir is scp-ed
audio_dir = args.audio_dir # this is scp-ed flac dir
encodec_signature = args.encodec_model_path.split("/")[-2]
save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}")
os.makedirs(save_codes_dir, exist_ok=True)
# model_sr = 16000
# downsample_rate = 320
# model_code_sr = 50
def sort_by_audio_len(lens):
inds = np.argsort(lens).tolist()
logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
return inds[::-1]
def write_array_to_txt_file(array, filename):
with open(filename, 'w') as f:
for a in array[:-1]:
f.write(' '.join(map(str, a))+'\n')
f.write(' '.join(map(str, array[-1])))
class mydataset(torch.utils.data.Dataset):
def __init__(self, split):
super().__init__()
# self.data = gs[split]
self.split = split
self.audio_root = audio_dir
manifest_fn = os.path.join(manifest_dir, split+".txt")
with open(manifest_fn, "r") as rf:
self.data = [l.strip().split("\t") for l in rf.readlines()]
def __len__(self):
return len(self.data)
def __getitem__(self, ind):
try:
afn = self.data[ind][0]
fn = os.path.join(self.audio_root, afn)
audio, sr = torchaudio.load(fn)
assert sr == args.model_sr, sr
except Exception as e:
logging.info(f"{e}")
return None, None, None
assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0]
def collate(self, batch):
lens, audios, segment_ids = [], [], []
for item in batch:
if item[0] != None:
audios.append(item[0])
lens.append(item[1])
segment_ids.append(item[2])
return audios, lens, segment_ids
# load the encodec model
from audiocraft.solvers import CompressionSolver
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
model = model.cuda()
model = model.eval()
model = torch.nn.DataParallel(model)
# setup dataloader
mega_batch_size = 2100
batch_size = args.batch_size
train_dataset = mydataset('train')
train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
validation_dataset = mydataset('validation')
validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
test_dataset = mydataset('test')
test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
splits = ['validation', 'test', 'train']
loaders = [validation_loader, test_loader, train_loader]
# splits = ['validation'] # NOTE this is for debug, for example, see if the
# loaders = [validation_loader]
for split, loader in zip(splits, loaders):
skip = 0
logging.info(f"now processing split {split}...")
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
# mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples")
# with open(mani_fn, "a") as mani_wf: # resume from where we failed
for m, mega_batch in enumerate(loader):
logging.info(f"====================================")
logging.info(f"====================================")
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
lengths = np.array(mega_batch[1])
sorted_inds = sort_by_audio_len(lengths)
for j in range(len(sorted_inds))[::-1]:
if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s)
skip += 1
del sorted_inds[j]
n_steps = int(np.ceil(len(sorted_inds) / batch_size))
for n in tqdm.tqdm(range(n_steps), disable=True):
inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
wav_batch = [mega_batch[0][id] for id in inds_used]
all_lens = [mega_batch[1][id] for id in inds_used]
segment_id_batch = [mega_batch[2][id] for id in inds_used]
# print(segment_id_batch)
padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
with torch.no_grad():
if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes
codes = []
inwav = padded_wav.cuda()
codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu())
codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu())
codes = torch.cat(codes, dim=0)
else:
encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
codes = encoded_frames[0].cpu()
for i, length in enumerate(all_lens):
save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
write_array_to_txt_file(cur_code, save_fn)
# mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
# if i == 10:
# raise
# break
# logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
# break


@ -54,8 +54,6 @@ class dataset(torch.utils.data.Dataset):
y = [[int(n)+self.args.n_special for n in l] for l in encos]
else:
y = [[int(n) for n in l] for l in encos]
if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
y = y[:1]
except Exception as e:
logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
logging.info(f"error message: {e}")
@ -141,15 +139,15 @@ class dataset(torch.utils.data.Dataset):
if self.args.pad_x:
res["x"] = torch.stack(out["x"], dim=0)
else:
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
res["x_lens"] = torch.LongTensor(out["x_len"])
if self.args.dynamic_batching:
if out['y'][0].ndim==2:
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
else:
assert out['y'][0].ndim==1, out['y'][0].shape
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
else:
res['y'] = torch.stack(out['y'], dim=0)
res["y_lens"] = torch.LongTensor(out["y_len"])


@ -0,0 +1,206 @@
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="phonemize and encode the gigaspeech dataset using the encodec model")
parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl. we use xl for VoiceCraft training, xs is good for debugging')
parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the total batch size summed over all GPUs, so increase it if you are using more GPUs")
parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number of seconds')
parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine')
return parser.parse_args()
if __name__ == "__main__":
import logging
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = parse_args()
import os
import numpy as np
import torch
import tqdm
import time
from datasets import load_dataset, DownloadConfig
from tokenizer import TextTokenizer, tokenize_text
# get the path
phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
os.makedirs(phn_save_root, exist_ok=True)
os.makedirs(codes_save_root, exist_ok=True)
def sort_by_audio_len(lens):
inds = np.argsort(lens).tolist()
logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
return inds[::-1]
def write_array_to_txt_file(array, filename):
with open(filename, 'w') as f:
for a in array[:-1]:
f.write(' '.join(map(str, a))+'\n')
f.write(' '.join(map(str, array[-1])))
### phonemization
# load tokenizer
# load the encodec model
from audiocraft.solvers import CompressionSolver
model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
model = model.cuda()
model = model.eval()
text_tokenizer = TextTokenizer()
# https://github.com/SpeechColab/GigaSpeech
# there are only four different punctuations
# need to check whether there are other strings that start with "<"
punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are safely kept as the original symbols when using tokenize_text
punc2sym.update(gar2sym)
word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
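# to summarize the mappings above: punc2sym turns GigaSpeech punctuation tags into real punctuation,
# gar2sym turns garbage tags (<SIL>, <MUSIC>, ...) into rare '#'/'%' strings that survive phonemization
# (phonemized as "hash"/"percent"), and word2sym maps those phonemized forms back to the original tags;
# utterances whose text already contains the raw symbol strings (forbidden_words) are skipped below,
# presumably to avoid collisions with these placeholders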
dc = DownloadConfig(cache_dir=args.download_to)
stime = time.time()
logging.info("loading the dataset...")
gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc)
logging.info(f"time spend on loading the dataset: {time.time() - stime:.2f} seconds")
splits = ['validation', 'test', 'train']
logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}")
logging.info(f"phonemizing...")
phn_vocab = set()
all_lens = []
# you will see a ton of [WARNING] words_mismatch.py:88......, it's not an issue
for split in tqdm.tqdm(splits):
skip = 0
logging.info(f"now processing split {split}...")
for item in tqdm.tqdm(gs[split]):
save_fn = os.path.join(phn_save_root, item['segment_id']+".txt")
text = item['text']
if sum(word in forbidden_words for word in text.split(" ")):
logging.info(f"skip {item['segment_id']}, because it contains forbiden words. It's transcript: {text}")
skip += 1
continue
for k, v in punc2sym.items():
text = text.replace(k, v)
phn = tokenize_text(text_tokenizer, text)
phn_seq = " ".join(phn)
for k, v in word2sym.items():
phn_seq = phn_seq.replace(k, v)
phn_vocab.update(phn_seq.split(" "))
all_lens.append(len(phn_seq.split(" ")))
with open(save_fn, "w") as f:
f.write(phn_seq)
logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
print(f"phn vocab size: {len(list(phn_vocab))}")
print("phn sequence stats: ")
print(f"longest: {max(all_lens)}")
print(f"shortest: {min(all_lens)}")
print(f"median: {np.quantile(all_lens, 0.5)}")
print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}")
print("write vocabulary to ", vocab_fn)
with open(vocab_fn, "w") as f:
for i, phn in enumerate(list(phn_vocab)):
if i < len(list(phn_vocab)) - 1:
f.write(f"{str(i)} {phn}\n")
else:
f.write(f"{str(i)} {phn}")
class mydataset(torch.utils.data.Dataset):
def __init__(self, split):
super().__init__()
self.data = gs[split]
def __len__(self):
return len(self.data)
def __getitem__(self, ind):
try:
segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time']
except:
return None, None, None, None, None, None
return segment_id, audio, sr, text, begin_time, end_time
def collate(self, batch):
res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []}
for item in batch:
if item[0] != None:
res['segment_id'].append(item[0])
res['audio'].append(item[1])
res['sr'].append(item[2])
res['text'].append(item[3])
res['begin_time'].append(item[4])
res['end_time'].append(item[5])
return res
## encodec codes extraction
logging.info("encodec encoding...")
train_dataset = mydataset('train')
train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
validation_dataset = mydataset('validation')
validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
test_dataset = mydataset('test')
test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
splits = ['validation', 'test', 'train']
loaders = [validation_loader, test_loader, train_loader]
# splits = ['validation'] # for debug
# loaders = [validation_loader]
for split, loader in zip(splits, loaders):
skip = 0
logging.info(f"now processing split {split}...")
mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples")
for m, mega_batch in enumerate(loader):
logging.info(f"====================================")
logging.info(f"====================================")
logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time'])
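# note: sorting the mega batch by duration (descending) groups clips of similar length into the
# encodec batches formed below, which keeps padding small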
sorted_inds = sort_by_audio_len(lengths)
for j in range(len(sorted_inds))[::-1]:
if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s) or too long (longer than args.len_cap seconds)
skip += 1
del sorted_inds[j]
n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
for n in tqdm.tqdm(range(n_steps), disable=True):
inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
audio_batch = [mega_batch['audio'][id] for id in inds_used]
sr_batch = [mega_batch['sr'][id] for id in inds_used]
segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used]
text_batch = [mega_batch['text'][id] for id in inds_used]
padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
all_lens = [lengths[id] for id in inds_used]
with torch.no_grad():
if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes
codes = []
inwav = padded_wav.cuda()
codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
codes = torch.cat(codes, dim=0)
else:
encoded_frames = model.encode(padded_wav.cuda())
# logging.info(f"encoded_frames: {encoded_frames[0].shape}")
codes = encoded_frames[0].cpu()
for i, length in enumerate(all_lens):
save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
actual_len = round(length * args.model_code_sr) # model_code_sr is 50 codes per second for this model
cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
write_array_to_txt_file(cur_code, save_fn)

environment.yml (new file)

@ -0,0 +1,417 @@
name: voicecraft
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- aom=3.8.2=h59595ed_0
- asttokens=2.4.1=pyhd8ed1ab_0
- atk-1.0=2.38.0=hd4edc92_1
- audioread=3.0.1=py39hf3d152e_1
- backcall=0.2.0=pyh9f0ad1d_0
- baumwelch=0.3.7=h00ab1b0_5
- biopython=1.79=py39hb9d737c_3
- brotli=1.1.0=hd590300_1
- brotli-bin=1.1.0=hd590300_1
- brotli-python=1.1.0=py39h3d6467e_1
- bzip2=1.0.8=hd590300_5
- ca-certificates=2024.2.2=hbcca054_0
- cairo=1.18.0=h3faef2a_0
- certifi=2024.2.2=pyhd8ed1ab_0
- cffi=1.16.0=py39h7a31438_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- comm=0.2.2=pyhd8ed1ab_0
- contourpy=1.2.0=py39h7633fee_0
- cycler=0.12.1=pyhd8ed1ab_0
- dataclassy=1.0.1=pyhd8ed1ab_0
- dav1d=1.2.1=hd590300_0
- debugpy=1.8.1=py39h3d6467e_0
- decorator=5.1.1=pyhd8ed1ab_0
- executing=2.0.1=pyhd8ed1ab_0
- expat=2.6.2=h59595ed_0
- ffmpeg=6.1.1=gpl_h38e077a_106
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=h77eed37_1
- fontconfig=2.14.2=h14ed4e7_0
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- fonttools=4.49.0=py39hd1e30aa_0
- freetype=2.12.1=h267a509_2
- fribidi=1.0.10=h36c2ea0_0
- gdk-pixbuf=2.42.10=h829c605_5
- gettext=0.21.1=h27087fc_0
- giflib=5.2.1=h0b41bf4_3
- gmp=6.3.0=h59595ed_1
- gnutls=3.7.9=hb077bed_0
- graphite2=1.3.13=h58526e2_1001
- graphviz=9.0.0=h78e8752_1
- greenlet=3.0.3=py39h3d6467e_0
- gtk2=2.24.33=h280cfa0_4
- gts=0.7.6=h977cf35_4
- harfbuzz=8.3.0=h3d44ed6_0
- hdbscan=0.8.33=py39h44dd56e_4
- icu=73.2=h59595ed_0
- idna=3.6=pyhd8ed1ab_0
- importlib-metadata=7.0.2=pyha770c72_0
- importlib-resources=6.3.0=pyhd8ed1ab_0
- importlib_metadata=7.0.2=hd8ed1ab_0
- importlib_resources=6.3.0=pyhd8ed1ab_0
- ipykernel=6.29.3=pyhd33586a_0
- jedi=0.19.1=pyhd8ed1ab_0
- joblib=1.3.2=pyhd8ed1ab_0
- jupyter_client=8.6.1=pyhd8ed1ab_0
- jupyter_core=5.7.2=py39hf3d152e_0
- kaldi=5.5.1068=cpu_h31769b2_2
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.5=py39h7633fee_1
- kneed=0.8.5=pyhd8ed1ab_0
- krb5=1.21.2=h659d440_0
- lame=3.100=h166bdaf_1003
- lazy_loader=0.3=pyhd8ed1ab_0
- lcms2=2.16=hb7c19ff_0
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libabseil=20240116.1=cxx17_h59595ed_2
- libass=0.17.1=h8fe9dca_1
- libblas=3.9.0=21_linux64_openblas
- libbrotlicommon=1.1.0=hd590300_1
- libbrotlidec=1.1.0=hd590300_1
- libbrotlienc=1.1.0=hd590300_1
- libcblas=3.9.0=21_linux64_openblas
- libclang-cpp15=15.0.7=default_hb11cfb5_4
- libdeflate=1.19=hd590300_0
- libdrm=2.4.120=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libexpat=2.6.2=h59595ed_0
- libffi=3.4.2=h7f98852_5
- libflac=1.4.3=h59595ed_0
- libgcc-ng=13.2.0=h807b86a_5
- libgd=2.3.3=h119a65a_9
- libgfortran-ng=13.2.0=h69a702a_5
- libgfortran5=13.2.0=ha4646dd_5
- libglib=2.80.0=hf2295e7_0
- libgomp=13.2.0=h807b86a_5
- libhwloc=2.9.3=default_h554bfaf_1009
- libiconv=1.17=hd590300_2
- libidn2=2.3.7=hd590300_0
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=21_linux64_openblas
- liblapacke=3.9.0=21_linux64_openblas
- libllvm14=14.0.6=hcd5def8_4
- libllvm15=15.0.7=hb3ce162_4
- libllvmspirv15=15.0.0=h0cdce71_1
- libnsl=2.0.1=hd590300_0
- libogg=1.3.4=h7f98852_1
- libopenblas=0.3.26=pthreads_h413a1c8_0
- libopenvino=2024.0.0=h2e90f83_1
- libopenvino-auto-batch-plugin=2024.0.0=hd5fc58b_1
- libopenvino-auto-plugin=2024.0.0=hd5fc58b_1
- libopenvino-hetero-plugin=2024.0.0=h3ecfda7_1
- libopenvino-intel-cpu-plugin=2024.0.0=h2e90f83_1
- libopenvino-intel-gpu-plugin=2024.0.0=h2e90f83_1
- libopenvino-ir-frontend=2024.0.0=h3ecfda7_1
- libopenvino-onnx-frontend=2024.0.0=h757c851_1
- libopenvino-paddle-frontend=2024.0.0=h757c851_1
- libopenvino-pytorch-frontend=2024.0.0=h59595ed_1
- libopenvino-tensorflow-frontend=2024.0.0=hca94c1a_1
- libopenvino-tensorflow-lite-frontend=2024.0.0=h59595ed_1
- libopus=1.3.1=h7f98852_1
- libpciaccess=0.18=hd590300_0
- libpng=1.6.43=h2797004_0
- libpq=16.2=h33b98f1_0
- libprotobuf=4.25.3=h08a7969_0
- librosa=0.10.1=pyhd8ed1ab_0
- librsvg=2.56.3=he3f83f7_1
- libsndfile=1.2.2=hc60ed4a_1
- libsodium=1.0.18=h36c2ea0_1
- libsqlite=3.45.2=h2797004_0
- libstdcxx-ng=13.2.0=h7e041cc_5
- libtasn1=4.19.0=h166bdaf_0
- libtiff=4.6.0=ha9c0a0a_2
- libunistring=0.9.10=h7f98852_0
- libuuid=2.38.1=h0b41bf4_0
- libva=2.21.0=hd590300_0
- libvorbis=1.3.7=h9c3ff4c_0
- libvpx=1.14.0=h59595ed_0
- libwebp=1.3.2=h658648e_1
- libwebp-base=1.3.2=hd590300_0
- libxcb=1.15=h0b41bf4_0
- libxcrypt=4.4.36=hd590300_1
- libxml2=2.12.5=h232c23b_0
- libzlib=1.2.13=hd590300_5
- llvm-spirv-15=15.0.0=h0cdce71_1
- mad=0.15.1b=h9c3ff4c_1
- markdown-it-py=3.0.0=pyhd8ed1ab_0
- matplotlib-base=3.8.3=py39he9076e7_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mdurl=0.1.2=pyhd8ed1ab_0
- montreal-forced-aligner=2.2.17=pyhd8ed1ab_0
- mpg123=1.32.4=h59595ed_0
- msgpack-python=1.0.7=py39h7633fee_0
- munkres=1.1.4=pyh9f0ad1d_0
- ncurses=6.4=h59595ed_2
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- nettle=3.9.1=h7ab15ed_0
- ngram=1.3.14=h924138e_2
- numba=0.59.0=py39h615d6bd_1
- numpy=1.26.4=py39h474f0d3_0
- ocl-icd=2.3.2=hd590300_0
- openfst=1.8.2=h924138e_2
- openh264=2.4.1=h59595ed_0
- openjpeg=2.5.2=h488ebb8_0
- openssl=3.2.1=hd590300_0
- p11-kit=0.24.1=hc5aa10d_0
- packaging=24.0=pyhd8ed1ab_0
- pandas=2.2.1=py39hddac248_0
- pango=1.52.1=ha41ecd1_0
- parso=0.8.3=pyhd8ed1ab_0
- patsy=0.5.6=pyhd8ed1ab_0
- pcre2=10.43=hcad00b1_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pgvector-python=0.2.5=pyhe093146_0
- pickleshare=0.7.5=py_1003
- pillow=10.2.0=py39had0adad_0
- pip=24.0=pyhd8ed1ab_0
- pixman=0.43.2=h59595ed_0
- platformdirs=4.2.0=pyhd8ed1ab_0
- pocl=5.0=h03a6ac1_2
- pocl-core=5.0=hdaecddf_2
- pocl-cpu=5.0=he901f76_2
- pocl-cpu-minimal=5.0=h5ccd973_2
- pocl-cuda=5.0=hdaecddf_2
- pocl-remote=5.0=h5ccd973_2
- pooch=1.8.1=pyhd8ed1ab_0
- postgresql=16.2=h7387d8b_0
- prompt-toolkit=3.0.42=pyha770c72_0
- prompt_toolkit=3.0.42=hd8ed1ab_0
- psutil=5.9.8=py39hd1e30aa_0
- psycopg2=2.9.9=py39h89197e3_0
- pthread-stubs=0.4=h36c2ea0_1001
- ptyprocess=0.7.0=pyhd3deb0d_0
- pugixml=1.14=h59595ed_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd8ed1ab_0
- pygments=2.17.2=pyhd8ed1ab_0
- pyparsing=3.1.2=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- pysoundfile=0.12.1=pypyhd8ed1ab_1
- python=3.9.18=h0755675_1_cpython
- python-tzdata=2024.1=pyhd8ed1ab_0
- python_abi=3.9=4_cp39
- pytz=2024.1=pyhd8ed1ab_0
- pyyaml=6.0.1=py39hd1e30aa_1
- pyzmq=25.1.2=py39h8c080ef_0
- readline=8.2=h8228510_1
- requests=2.31.0=pyhd8ed1ab_0
- rich=13.7.1=pyhd8ed1ab_0
- rich-click=1.7.4=pyhd8ed1ab_0
- scikit-learn=1.2.2=py39hc236052_2
- scipy=1.12.0=py39h474f0d3_2
- seaborn=0.13.2=hd8ed1ab_0
- seaborn-base=0.13.2=pyhd8ed1ab_0
- setuptools=69.2.0=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- snappy=1.1.10=h9fff704_0
- sox=14.4.2=ha5cc309_1018
- soxr=0.1.3=h0b41bf4_3
- soxr-python=0.3.7=py39h44dd56e_0
- sqlalchemy=2.0.28=py39hd1e30aa_0
- sqlite=3.45.2=h2c6b66d_0
- stack_data=0.6.2=pyhd8ed1ab_0
- statsmodels=0.14.1=py39h44dd56e_0
- svt-av1=1.8.0=h59595ed_0
- tbb=2021.11.0=h00ab1b0_1
- threadpoolctl=3.3.0=pyhc1e730c_0
- tk=8.6.13=noxft_h4845f30_101
- tornado=6.4=py39hd1e30aa_0
- tqdm=4.66.2=pyhd8ed1ab_0
- traitlets=5.14.2=pyhd8ed1ab_0
- typing-extensions=4.10.0=hd8ed1ab_0
- typing_extensions=4.10.0=pyha770c72_0
- tzcode=2024a=h3f72095_0
- tzdata=2024a=h0c530f3_0
- unicodedata2=15.1.0=py39hd1e30aa_0
- urllib3=2.2.1=pyhd8ed1ab_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.42.0=pyhd8ed1ab_0
- x264=1!164.3095=h166bdaf_2
- x265=3.5=h924138e_3
- xorg-fixesproto=5.0=h7f98852_1002
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libice=1.1.1=hd590300_0
- xorg-libsm=1.2.4=h7391055_0
- xorg-libx11=1.8.7=h8ee46fc_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xorg-libxext=1.3.4=h0b41bf4_2
- xorg-libxfixes=5.0.3=h7f98852_1004
- xorg-libxrender=0.9.11=hd590300_0
- xorg-renderproto=0.11.1=h7f98852_1002
- xorg-xextproto=7.3.0=h0b41bf4_1003
- xorg-xproto=7.0.31=h7f98852_1007
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- zeromq=4.3.5=h59595ed_1
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
- pip:
- absl-py==2.1.0
- aiofiles==23.2.1
- aiohttp==3.9.3
- aiosignal==1.3.1
- altair==5.2.0
- antlr4-python3-runtime==4.9.3
- anyio==4.3.0
- async-timeout==4.0.3
- attrs==23.2.0
- av==11.0.0
- babel==2.14.0
- beautifulsoup4==4.12.3
- bibtexparser==2.0.0b7
- bleach==6.1.0
- blis==0.7.11
- catalogue==2.0.10
- clldutils==3.22.2
- cloudpickle==3.0.0
- cmake==3.28.3
- colorlog==6.8.2
- confection==0.1.4
- csvw==3.3.0
- cymem==2.0.8
- cython==0.29.37
- datasets==2.16.0
- defusedxml==0.7.1
- demucs==4.0.1
- dill==0.3.6
- dlinfo==1.2.1
- docopt==0.6.2
- dora-search==0.1.12
- einops==0.7.0
- encodec==0.1.1
- exceptiongroup==1.2.0
- fastapi==0.110.0
- fastjsonschema==2.19.1
- ffmpy==0.3.2
- filelock==3.13.1
- flashy==0.0.2
- frozenlist==1.4.1
- fsspec==2023.10.0
- gradio==3.50.2
- gradio-client==0.6.1
- grpcio==1.62.1
- h11==0.14.0
- httpcore==1.0.4
- httpx==0.27.0
- huggingface-hub==0.21.4
- hydra-colorlog==1.2.0
- hydra-core==1.3.2
- ipython==8.12.3
- isodate==0.6.1
- jinja2==3.1.3
- jsonschema==4.21.1
- jsonschema-specifications==2023.12.1
- julius==0.2.7
- jupyterlab-pygments==0.3.0
- lameenc==1.7.0
- langcodes==3.3.0
- language-tags==1.2.0
- lit==18.1.1
- llvmlite==0.42.0
- lxml==5.1.0
- markdown==3.5.2
- markupsafe==2.1.5
- mistune==3.0.2
- mpmath==1.3.0
- msgpack==1.0.8
- multidict==6.0.5
- multiprocess==0.70.14
- murmurhash==1.0.10
- nbclient==0.10.0
- nbconvert==7.16.3
- nbformat==5.10.3
- networkx==3.2.1
- num2words==0.5.13
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cufft-cu11==10.9.0.58
- nvidia-curand-cu11==10.2.10.91
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-nccl-cu11==2.14.3
- nvidia-nvtx-cu11==11.7.91
- omegaconf==2.3.0
- openunmix==1.2.1
- orjson==3.9.15
- pandocfilters==1.5.1
- pathlib-abc==0.1.1
- pathy==0.11.0
- pgvector==0.2.2
- phonemizer==3.2.1
- pipreqs==0.5.0
- praatio==6.2.0
- preshed==3.0.9
- protobuf==4.25.3
- pyarrow==15.0.2
- pyarrow-hotfix==0.6
- pydantic==1.10.14
- pydub==0.25.1
- pylatexenc==2.10
- pynini==2.1.6
- pypinyin==0.48.0
- python-dateutil==2.9.0.post0
- python-multipart==0.0.9
- rdflib==7.0.0
- referencing==0.33.0
- regex==2023.12.25
- responses==0.18.0
- retrying==1.3.4
- rfc3986==1.5.0
- rpds-py==0.18.0
- safetensors==0.4.2
- segments==2.2.1
- semantic-version==2.10.0
- sentencepiece==0.2.0
- smart-open==6.4.0
- sniffio==1.3.1
- soupsieve==2.5
- spacy==3.5.2
- spacy-legacy==3.0.12
- spacy-loggers==1.0.5
- srsly==2.4.8
- starlette==0.36.3
- submitit==1.5.1
- sympy==1.12
- tabulate==0.9.0
- tensorboard==2.16.2
- tensorboard-data-server==0.7.2
- thinc==8.1.12
- tinycss2==1.2.1
- tokenizers==0.15.2
- toolz==0.12.1
- torch==2.0.1
- torchaudio==2.0.2
- torchmetrics==0.11.1
- transformers==4.38.2
- treetable==0.2.5
- triton==2.0.0
- typer==0.7.0
- uritemplate==4.1.1
- uvicorn==0.28.0
- wasabi==1.1.2
- webencodings==0.5.1
- websockets==11.0.3
- werkzeug==3.0.1
- xformers==0.0.22
- xxhash==3.4.1
- yarg==0.1.9
- yarl==1.9.4
prefix: /home/pyp/miniconda3/envs/voicecraft


@ -504,7 +504,7 @@ class VoiceCraft(nn.Module):
ntokens = []
top10acc = []
for k, (logit, target) in enumerate(zip(logits, targets)):
loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None))
loss.append(F.cross_entropy(logit, target, reduction='mean'))
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
ntokens.append(len(logit))
@ -988,6 +988,8 @@ class VoiceCraft(nn.Module):
for jj in range(1,self.args.n_codebooks):
logits_adjust[jj][eog_inference] = -10000
logits_adjust[jj][self.args.empty_token] = -10000
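# note: assuming encodec_sr is the codec frame rate in codes per second, encodec_sr // 5 below
# corresponds to roughly the first 0.2 s of generation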
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
logits_adjust[0][eog_inference] = -10000
##################### silence repetition handling #####################
if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
if logits_adjust[0, prev_token] < 0:
@ -1237,6 +1239,8 @@ class VoiceCraft(nn.Module):
for jj in range(1,self.args.n_codebooks):
logits_adjust[:,jj,eog_inference] = -10000
logits_adjust[:,jj,self.args.empty_token] = -10000
if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
logits_adjust[:,:,eog_inference] = -10000
##################### silence repetition handling #####################
for b in range(batch_size):
prev_token = prev_tokens[b]


@ -7,9 +7,9 @@ export WORLD_SIZE=4
dataset=gigaspeech
mkdir -p ./logs/${dataset}
exp_root="/data/scratch/pyp/exp_pyp/VoiceCraft"
exp_root="path/to/store/exp_results"
exp_name=e830M
dataset_dir="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl"
dataset_dir="path/to/stored_extracted_codes_and_phonemes/xl" # xs if you only extracted xs in previous step
encodec_codes_folder_name="encodec_16khz_4codebooks"
# export CUDA_LAUNCH_BLOCKING=1 # for debugging
@ -51,7 +51,7 @@ torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_
--text_vocab_size 100 \
--text_pad_token 100 \
--phn_folder_name "phonemes" \
--manifest_name "manifest_large16khz_lessambi" \
--manifest_name "manifest" \
--encodec_folder_name ${encodec_codes_folder_name} \
--audio_vocab_size 2048 \
--empty_token 2048 \