diff --git a/README.md b/README.md
index 1e13a32..26d4c1b 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
 # VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
 [Demo](https://jasonppy.github.io/VoiceCraft_web) [Paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf)
-TL;DR:
+### TL;DR
 VoiceCraft is a token-infilling neural codec language model that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data, including audiobooks, internet videos, and podcasts.
 To clone or edit an unseen voice, VoiceCraft needs only a few seconds of reference audio.
@@ -12,22 +12,25 @@
 The remaining TODOs will be completed by the end of March 2024.
 - [x] Codebase upload
 - [x] Environment setup
 - [x] Inference demo for speech editing and TTS
-- [ ] Upload model weights
-- [ ] Training guidance
-- [ ] Upload the RealEdit dataset
+- [x] Training guidance
+- [x] Upload the RealEdit dataset and training manifest
+- [ ] Upload model weights (encodec weights are up)
+
 ## Environment setup
 ```bash
 conda create -n voicecraft python=3.9.16
 conda activate voicecraft
-pip install torch==2.0.1 torchaudio==2.0.2 # this assumes your system is compatible with CUDA 11.7, otherwise checkout https://pytorch.org/get-started/previous-versions/#v201
+pip install torch==2.0.1 # this assumes your system is compatible with CUDA 11.7; otherwise check out https://pytorch.org/get-started/previous-versions/#v201
 apt-get install ffmpeg # if you don't already have ffmpeg installed
 pip install -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
 apt-get install espeak-ng # backend for the phonemizer installed below
+pip install tensorboard==2.16.2
 pip install phonemizer==3.2.1
-pip install tensorboard
-pip install datasets==2.12.0
+pip install torchaudio==2.0.2
+pip install datasets==2.16.0
+pip install torchmetrics==0.11.1
 # install MFA for forced alignment; this could take a few minutes
 conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=5.5.1068
 # conda install pocl # the above gives a warning about pocl; not sure if it's really needed
@@ -36,9 +39,51 @@ conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=
 conda install -n voicecraft ipykernel --update-deps --force-reinstall
 ```
+If you run into version issues when running things, check [environment.yml](./environment.yml) for exact package versions.
+
 ## Inference Examples
 Check out [`inference_speech_editing.ipynb`](./inference_speech_editing.ipynb) and [`inference_tts.ipynb`](./inference_tts.ipynb)
 
+## Training
+To train a VoiceCraft model, you need to prepare the following:
+1. utterances and their transcripts
+2. encodec codes for the utterances, extracted with e.g. Encodec
+3. phoneme sequences for the transcripts, plus a phoneme set (named vocab.txt)
+4. a manifest (i.e. metadata)
+
+Steps 1, 2, and 3 are handled in [./data/phonemize_encodec_encode_hf.py](./data/phonemize_encodec_encode_hf.py), where
+1. GigaSpeech is downloaded through HuggingFace. Note that you need to sign an agreement to download the dataset (it needs your auth token)
+2. phoneme sequences and encodec codes are extracted by the script (the phonemization step is sketched below)
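+
+To make step 3 concrete, below is a minimal sketch of the phonemization step using the helpers from [./data/tokenizer.py](./data/tokenizer.py) that the script relies on (the bare `tokenizer` import assumes you run it from inside `data/`; the example sentence is arbitrary):
+
+```python
+# sketch: transcript -> space-separated IPA phoneme sequence,
+# i.e. the format stored in phonemes/<segment_id>.txt
+from tokenizer import TextTokenizer, tokenize_text
+
+text_tokenizer = TextTokenizer()  # backed by the espeak-ng phonemizer installed above
+phn = tokenize_text(text_tokenizer, "VoiceCraft can edit speech")
+print(" ".join(phn))
+```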
+
+An example run:
+
+```bash
+conda activate voicecraft
+export CUDA_VISIBLE_DEVICES=0
+cd ./data
+python phonemize_encodec_encode_hf.py \
+--dataset_size xs \
+--download_to path/to/store_huggingface_downloads \
+--save_dir path/to/store_extracted_codes_and_phonemes \
+--encodec_model_path path/to/encodec_model \
+--mega_batch_size 120 \
+--batch_size 32 \
+--max_len 30000
+```
+where `encodec_model_path` is available [here](https://huggingface.co/pyp1/VoiceCraft). This model was trained on GigaSpeech XL; it has 56M parameters, with 4 codebooks of 2048 codes each. Details are described in our [paper](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf). If you encounter OOM during extraction, try decreasing `batch_size` and/or `max_len`.
+The extracted codes, phonemes, and vocab.txt will be stored at `path/to/store_extracted_codes_and_phonemes/${dataset_size}/{encodec_16khz_4codebooks,phonemes,vocab.txt}`.
+
+As for the manifest, please download train.txt and validation.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) and put them under `path/to/store_extracted_codes_and_phonemes/manifest/`. Please also download vocab.txt from [here](https://huggingface.co/datasets/pyp1/VoiceCraft_RealEdit/tree/main) if you want to use our pretrained VoiceCraft model (so that the phoneme-to-token matching is the same).
+
+Now you are good to start training!
+
+```bash
+conda activate voicecraft
+cd ./z_scripts
+bash e830M.sh
+```
+
+
 ## License
 The codebase is under CC BY-NC-SA 4.0 ([LICENSE-CODE](./LICENSE-CODE)), and the model weights are under Coqui Public Model License 1.0.0 ([LICENSE-MODEL](./LICENSE-MODEL)). Note that we use some code from other repositories that is under different licenses: `./models/codebooks_patterns.py` is under the MIT license; `./models/modules`, `./steps/optim.py`, and `data/tokenizer.py` are under Apache License, Version 2.0; the phonemizer we use is under the GPL 3.0 license. For a drop-in replacement of the phonemizer (i.e. text to IPA phoneme mapping), try [g2p](https://github.com/roedoejet/g2p) (MIT License) or [OpenPhonemizer](https://github.com/NeuralVox/OpenPhonemizer) (BSD-3-Clause Clear), although these are not tested.
diff --git a/data/giga_preprocessing/encodec_encode.py b/data/giga_preprocessing/encodec_encode.py
deleted file mode 100644
index f2a9915..0000000
--- a/data/giga_preprocessing/encodec_encode.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import argparse
-def parse_args():
-    parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
-    parser.add_argument("--manifest_root", type=str, default="/home/pyp/audiocraft/egs/gigaspeech", help="this the dir of the audiocraft manifest!")
-    parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_flac", help="Path dirs of the flac audio files")
-    parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl", help="path to the manifest, phonemes, and encodec codes dirs")
-    parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
-    parser.add_argument('--n_workers', type=int, default=32, help="Number of parallel worker processes")
-    parser.add_argument('--batch_size', type=int, default=64, help="batch size for encodec encoding, decrease it if OOM.
This is the sum of batch size *over each gpu*, so increase it if you are using more gpus") - parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate') - parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate') - parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate') - parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number') - return parser.parse_args() - -if __name__ == "__main__": - import logging - formatter = ( - "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" - ) - logging.basicConfig(format=formatter, level=logging.INFO) - - import os - import numpy as np - import torch - import torchaudio - import tqdm - import time - - args = parse_args() - - manifest_dir = args.manifest_root # this dir is scp-ed - audio_dir = args.audio_dir # this is scp-ed flac dir - encodec_signature = args.encodec_model_path.split("/")[-2] - save_codes_dir = os.path.join(args.save_dir, f"encodec_16khz_{encodec_signature}") - os.makedirs(save_codes_dir, exist_ok=True) - - - # model_sr = 16000 - # downsample_rate = 320 - # model_code_sr = 50 - def sort_by_audio_len(lens): - inds = np.argsort(lens).tolist() - logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.") - logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.") - logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.") - logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.") - return inds[::-1] - - def write_array_to_txt_file(array, filename): - with open(filename, 'w') as f: - for a in array[:-1]: - f.write(' '.join(map(str, a))+'\n') - f.write(' '.join(map(str, array[-1]))) - - - - class mydataset(torch.utils.data.Dataset): - def __init__(self, split): - super().__init__() - # self.data = gs[split] - self.split = split - self.audio_root = audio_dir - manifest_fn = os.path.join(manifest_dir, split+".txt") - with open(manifest_fn, "r") as rf: - self.data = [l.strip().split("\t") for l in rf.readlines()] - def __len__(self): - return len(self.data) - def __getitem__(self, ind): - try: - afn = self.data[ind][0] - fn = os.path.join(self.audio_root, afn) - audio, sr = torchaudio.load(fn) - assert sr == args.model_sr, sr - except Exception as e: - logging.info(f"{e}") - return None, None, None - assert audio.ndim==2 and audio.shape[0] == 1, audio.shape - return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.basename(afn).split(".")[0] - def collate(self, batch): - lens, audios, segment_ids = [], [], [] - for item in batch: - if item[0] != None: - audios.append(item[0]) - lens.append(item[1]) - segment_ids.append(item[2]) - return audios, lens, segment_ids - - # load the encodec model - from audiocraft.solvers import CompressionSolver - model = CompressionSolver.model_from_checkpoint(args.encodec_model_path) - model = model.cuda() - model = model.eval() - model = torch.nn.DataParallel(model) - - - # setup dataloader - mega_batch_size = 2100 - batch_size = args.batch_size - train_dataset = mydataset('train') - train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, 
num_workers=args.n_workers, collate_fn=train_dataset.collate) - validation_dataset = mydataset('validation') - validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate) - test_dataset = mydataset('test') - test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate) - splits = ['validation', 'test', 'train'] - loaders = [validation_loader, test_loader, train_loader] - # splits = ['validation'] # NOTE this is for debug, for example, see if the - # loaders = [validation_loader] - for split, loader in zip(splits, loaders): - skip = 0 - logging.info(f"now processing split {split}...") - mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size)) - # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size)) - logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {mega_batch_size} samples") - # with open(mani_fn, "a") as mani_wf: # resume from where we failed - for m, mega_batch in enumerate(loader): - logging.info(f"====================================") - logging.info(f"====================================") - logging.info(f"now processing mega step {m+1}/{mega_n_steps}") - lengths = np.array(mega_batch[1]) - sorted_inds = sort_by_audio_len(lengths) - for j in range(len(sorted_inds))[::-1]: - if lengths[sorted_inds[j]] < args.model_sr*0.2 or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) - skip += 1 - del sorted_inds[j] - - n_steps = int(np.ceil(len(sorted_inds) / batch_size)) - for n in tqdm.tqdm(range(n_steps), disable=True): - inds_used = sorted_inds[n*batch_size:(n+1)*batch_size] - wav_batch = [mega_batch[0][id] for id in inds_used] - all_lens = [mega_batch[1][id] for id in inds_used] - segment_id_batch = [mega_batch[2][id] for id in inds_used] - # print(segment_id_batch) - padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T] - with torch.no_grad(): - if max(all_lens) > 300000 and len(all_lens) > 1: # NOTE decrease this (300000) if OOM, or chunk it into more than 2 forward passes - codes = [] - inwav = padded_wav.cuda() - codes.append(model(inwav[:len(inwav)//2], encode=True)[0].cpu()) - codes.append(model(inwav[len(inwav)//2:], encode=True)[0].cpu()) - codes = torch.cat(codes, dim=0) - else: - encoded_frames = model(padded_wav.cuda(), encode=True) # wav needs to have shape [B, C, T], C is model.channels, which is 1 for the 24kHz encodec model - # logging.info(f"encoded_frames: {encoded_frames[0].shape}") - codes = encoded_frames[0].cpu() - - for i, length in enumerate(all_lens): - save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt") - actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model - cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() - write_array_to_txt_file(cur_code, save_fn) - - # mani_wf.write(f"0\t{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file - # if i == 10: - # raise - # break - # logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words") - logging.info(f"split {split} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short") - # break \ 
No newline at end of file
diff --git a/data/gigaspeech.py b/data/gigaspeech.py
index 0d855a6..c9cf751 100644
--- a/data/gigaspeech.py
+++ b/data/gigaspeech.py
@@ -54,8 +54,6 @@ class dataset(torch.utils.data.Dataset):
                 y = [[int(n)+self.args.n_special for n in l] for l in encos]
             else:
                 y = [[int(n) for n in l] for l in encos]
-            if self.args.training_stage == 1 and not self.args.valle and not (self.args.musicgen or self.args.valle_orig):
-                y = y[:1]
         except Exception as e:
             logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
             logging.info(f"error message: {e}")
@@ -141,15 +139,15 @@ class dataset(torch.utils.data.Dataset):
         if self.args.pad_x:
             res["x"] = torch.stack(out["x"], dim=0)
         else:
-            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.text_pad_token)
+            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
         res["x_lens"] = torch.LongTensor(out["x_len"])
         if self.args.dynamic_batching:
             if out['y'][0].ndim==2:
-                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
+                res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
                 res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
             else:
                 assert out['y'][0].ndim==1, out['y'][0].shape
-                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=0 if self.args.sep_special_token else self.args.audio_pad_token)
+                res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
         else:
             res['y'] = torch.stack(out['y'], dim=0)
         res["y_lens"] = torch.LongTensor(out["y_len"])
diff --git a/data/phonemize_encodec_encode_hf.py b/data/phonemize_encodec_encode_hf.py
new file mode 100644
index 0000000..da09ee6
--- /dev/null
+++ b/data/phonemize_encodec_encode_hf.py
@@ -0,0 +1,206 @@
+import argparse
+def parse_args():
+    parser = argparse.ArgumentParser(description="phonemize the gigaspeech dataset and encode it with an encodec model")
+    parser.add_argument("--dataset_size", type=str, default='xs', help='size of gigaspeech: xs, s, m, l, or xl. we use xl for VoiceCraft training; xs is good for debugging')
+    parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
+    parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
+    parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
+    parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
+    parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
+    parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the batch size summed *over all gpus*, so increase it if you are using more gpus")
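+    # the three arguments below describe the pretrained encodec checkpoint
+    # (16kHz input, 320x downsampling -> 50 codes/sec) and should match how it was trained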
+    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
+    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
+    parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
+    parser.add_argument('--len_cap', type=float, default=35.0, help='audios longer than this many seconds will be dropped')
+    parser.add_argument('--max_len', type=int, default=30000, help='max audio length in samples; if exceeded, a batch is cut in half for processing. decrease this number if OOM on your machine')
+    return parser.parse_args()
+if __name__ == "__main__":
+    import logging
+    formatter = (
+        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
+    )
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = parse_args()
+
+    import os
+    import numpy as np
+    import torch
+    import tqdm
+    import time
+    from datasets import load_dataset, DownloadConfig
+
+    from tokenizer import TextTokenizer, tokenize_text
+
+    # get the paths
+    phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
+    codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
+    vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
+    os.makedirs(phn_save_root, exist_ok=True)
+    os.makedirs(codes_save_root, exist_ok=True)
+
+
+    def sort_by_audio_len(lens):
+        inds = np.argsort(lens).tolist()
+        logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
+        logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
+        logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
+        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
+        return inds[::-1]
+
+    def write_array_to_txt_file(array, filename):
+        with open(filename, 'w') as f:
+            for a in array[:-1]:
+                f.write(' '.join(map(str, a))+'\n')
+            f.write(' '.join(map(str, array[-1])))
+
+
+    ### phonemization
+    # load the text tokenizer and the encodec model
+    from audiocraft.solvers import CompressionSolver
+    model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
+    model = model.cuda()
+    model = model.eval()
+    text_tokenizer = TextTokenizer()
+
+
+    # https://github.com/SpeechColab/GigaSpeech
+    # there are only four different punctuation tags
+    # need to check whether there are other strings starting with <
+    punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punctuation name
+    gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>": "%#%"} # so that they are safely kept as the original symbols when using tokenize_text
+    punc2sym.update(gar2sym)
+
+    word2sym = {"h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
+    forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
+
+    dc = DownloadConfig(cache_dir=args.download_to)
+    stime = time.time()
+    logging.info("loading the dataset...")
+    gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir=args.download_to, download_config=dc)
+    logging.info(f"time spent on loading the dataset: {time.time() - stime:.2f} seconds")
+
+    splits = ['validation', 'test', 'train']
+
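+    # first pass over the splits: phonemize every transcript and build the phoneme
+    # vocabulary; the placeholder symbols defined above let the GigaSpeech
+    # punctuation/garbage tags survive phonemization and get mapped back afterwards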
logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}") + logging.info(f"phonemizing...") + phn_vocab = set() + all_lens = [] + + # you will see a ton of [WARNING] words_mismatch.py:88......, it's not a issue + for split in tqdm.tqdm(splits): + skip = 0 + logging.info(f"now processing split {split}...") + for item in tqdm.tqdm(gs[split]): + save_fn = os.path.join(phn_save_root, item['segment_id']+".txt") + text = item['text'] + if sum(word in forbidden_words for word in text.split(" ")): + logging.info(f"skip {item['segment_id']}, because it contains forbiden words. It's transcript: {text}") + skip += 1 + continue + for k, v in punc2sym.items(): + text = text.replace(k, v) + phn = tokenize_text(text_tokenizer, text) + phn_seq = " ".join(phn) + for k, v in word2sym.items(): + phn_seq = phn_seq.replace(k, v) + phn_vocab.update(phn_seq.split(" ")) + all_lens.append(len(phn_seq.split(" "))) + with open(save_fn, "w") as f: + f.write(phn_seq) + logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words") + + print(f"phn vocab size: {len(list(phn_vocab))}") + print("phn sequence stats: ") + print(f"longest: {max(all_lens)}") + print(f"shortest: {min(all_lens)}") + print(f"median: {np.quantile(all_lens, 0.5)}") + print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}") + print("write vocabulary to ", vocab_fn) + with open(vocab_fn, "w") as f: + for i, phn in enumerate(list(phn_vocab)): + if i < len(list(phn_vocab)) - 1: + f.write(f"{str(i)} {phn}\n") + else: + f.write(f"{str(i)} {phn}") + + class mydataset(torch.utils.data.Dataset): + def __init__(self, split): + super().__init__() + self.data = gs[split] + def __len__(self): + return len(self.data) + def __getitem__(self, ind): + try: + segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time'] + except: + return None, None, None, None, None, None + + return segment_id, audio, sr, text, begin_time, end_time + def collate(self, batch): + res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []} + for item in batch: + if item[0] != None: + res['segment_id'].append(item[0]) + res['audio'].append(item[1]) + res['sr'].append(item[2]) + res['text'].append(item[3]) + res['begin_time'].append(item[4]) + res['end_time'].append(item[5]) + return res + + + ## encodec codes extraction + logging.info("encodec encoding...") + train_dataset = mydataset('train') + train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate) + validation_dataset = mydataset('validation') + validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate) + test_dataset = mydataset('test') + test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate) + splits = ['validation', 'test', 'train'] + loaders = [validation_loader, test_loader, train_loader] + # splits = ['validation'] # for debug + # loaders = [validation_loader] + for split, loader in zip(splits, loaders): + skip = 0 + 
logging.info(f"now processing split {split}...") + mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size)) + logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples") + for m, mega_batch in enumerate(loader): + logging.info(f"====================================") + logging.info(f"====================================") + logging.info(f"now processing mega step {m+1}/{mega_n_steps}") + lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time']) + sorted_inds = sort_by_audio_len(lengths) + for j in range(len(sorted_inds))[::-1]: + if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) + skip += 1 + del sorted_inds[j] + + n_steps = int(np.ceil(len(sorted_inds) / args.batch_size)) + for n in tqdm.tqdm(range(n_steps), disable=True): + inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size] + audio_batch = [mega_batch['audio'][id] for id in inds_used] + sr_batch = [mega_batch['sr'][id] for id in inds_used] + segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used] + text_batch = [mega_batch['text'][id] for id in inds_used] + padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T] + all_lens = [lengths[id] for id in inds_used] + with torch.no_grad(): + if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes + codes = [] + inwav = padded_wav.cuda() + codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu()) + codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu()) + codes = torch.cat(codes, dim=0) + else: + encoded_frames = model.encode(padded_wav.cuda()) + # logging.info(f"encoded_frames: {encoded_frames[0].shape}") + codes = encoded_frames[0].cpu() + + for i, length in enumerate(all_lens): + save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt") + actual_len = round(length * args.model_code_sr) # 320 is downsample rate for this model + cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() + write_array_to_txt_file(cur_code, save_fn) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..ca0906d --- /dev/null +++ b/environment.yml @@ -0,0 +1,417 @@ +name: voicecraft +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - aom=3.8.2=h59595ed_0 + - asttokens=2.4.1=pyhd8ed1ab_0 + - atk-1.0=2.38.0=hd4edc92_1 + - audioread=3.0.1=py39hf3d152e_1 + - backcall=0.2.0=pyh9f0ad1d_0 + - baumwelch=0.3.7=h00ab1b0_5 + - biopython=1.79=py39hb9d737c_3 + - brotli=1.1.0=hd590300_1 + - brotli-bin=1.1.0=hd590300_1 + - brotli-python=1.1.0=py39h3d6467e_1 + - bzip2=1.0.8=hd590300_5 + - ca-certificates=2024.2.2=hbcca054_0 + - cairo=1.18.0=h3faef2a_0 + - certifi=2024.2.2=pyhd8ed1ab_0 + - cffi=1.16.0=py39h7a31438_0 + - charset-normalizer=3.3.2=pyhd8ed1ab_0 + - click=8.1.7=unix_pyh707e725_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - comm=0.2.2=pyhd8ed1ab_0 + - contourpy=1.2.0=py39h7633fee_0 + - cycler=0.12.1=pyhd8ed1ab_0 + - dataclassy=1.0.1=pyhd8ed1ab_0 + - dav1d=1.2.1=hd590300_0 + - debugpy=1.8.1=py39h3d6467e_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - executing=2.0.1=pyhd8ed1ab_0 + - expat=2.6.2=h59595ed_0 + - ffmpeg=6.1.1=gpl_h38e077a_106 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - 
font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_1 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.49.0=py39hd1e30aa_0 + - freetype=2.12.1=h267a509_2 + - fribidi=1.0.10=h36c2ea0_0 + - gdk-pixbuf=2.42.10=h829c605_5 + - gettext=0.21.1=h27087fc_0 + - giflib=5.2.1=h0b41bf4_3 + - gmp=6.3.0=h59595ed_1 + - gnutls=3.7.9=hb077bed_0 + - graphite2=1.3.13=h58526e2_1001 + - graphviz=9.0.0=h78e8752_1 + - greenlet=3.0.3=py39h3d6467e_0 + - gtk2=2.24.33=h280cfa0_4 + - gts=0.7.6=h977cf35_4 + - harfbuzz=8.3.0=h3d44ed6_0 + - hdbscan=0.8.33=py39h44dd56e_4 + - icu=73.2=h59595ed_0 + - idna=3.6=pyhd8ed1ab_0 + - importlib-metadata=7.0.2=pyha770c72_0 + - importlib-resources=6.3.0=pyhd8ed1ab_0 + - importlib_metadata=7.0.2=hd8ed1ab_0 + - importlib_resources=6.3.0=pyhd8ed1ab_0 + - ipykernel=6.29.3=pyhd33586a_0 + - jedi=0.19.1=pyhd8ed1ab_0 + - joblib=1.3.2=pyhd8ed1ab_0 + - jupyter_client=8.6.1=pyhd8ed1ab_0 + - jupyter_core=5.7.2=py39hf3d152e_0 + - kaldi=5.5.1068=cpu_h31769b2_2 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.5=py39h7633fee_1 + - kneed=0.8.5=pyhd8ed1ab_0 + - krb5=1.21.2=h659d440_0 + - lame=3.100=h166bdaf_1003 + - lazy_loader=0.3=pyhd8ed1ab_0 + - lcms2=2.16=hb7c19ff_0 + - ld_impl_linux-64=2.40=h41732ed_0 + - lerc=4.0.0=h27087fc_0 + - libabseil=20240116.1=cxx17_h59595ed_2 + - libass=0.17.1=h8fe9dca_1 + - libblas=3.9.0=21_linux64_openblas + - libbrotlicommon=1.1.0=hd590300_1 + - libbrotlidec=1.1.0=hd590300_1 + - libbrotlienc=1.1.0=hd590300_1 + - libcblas=3.9.0=21_linux64_openblas + - libclang-cpp15=15.0.7=default_hb11cfb5_4 + - libdeflate=1.19=hd590300_0 + - libdrm=2.4.120=hd590300_0 + - libedit=3.1.20191231=he28a2e2_2 + - libexpat=2.6.2=h59595ed_0 + - libffi=3.4.2=h7f98852_5 + - libflac=1.4.3=h59595ed_0 + - libgcc-ng=13.2.0=h807b86a_5 + - libgd=2.3.3=h119a65a_9 + - libgfortran-ng=13.2.0=h69a702a_5 + - libgfortran5=13.2.0=ha4646dd_5 + - libglib=2.80.0=hf2295e7_0 + - libgomp=13.2.0=h807b86a_5 + - libhwloc=2.9.3=default_h554bfaf_1009 + - libiconv=1.17=hd590300_2 + - libidn2=2.3.7=hd590300_0 + - libjpeg-turbo=3.0.0=hd590300_1 + - liblapack=3.9.0=21_linux64_openblas + - liblapacke=3.9.0=21_linux64_openblas + - libllvm14=14.0.6=hcd5def8_4 + - libllvm15=15.0.7=hb3ce162_4 + - libllvmspirv15=15.0.0=h0cdce71_1 + - libnsl=2.0.1=hd590300_0 + - libogg=1.3.4=h7f98852_1 + - libopenblas=0.3.26=pthreads_h413a1c8_0 + - libopenvino=2024.0.0=h2e90f83_1 + - libopenvino-auto-batch-plugin=2024.0.0=hd5fc58b_1 + - libopenvino-auto-plugin=2024.0.0=hd5fc58b_1 + - libopenvino-hetero-plugin=2024.0.0=h3ecfda7_1 + - libopenvino-intel-cpu-plugin=2024.0.0=h2e90f83_1 + - libopenvino-intel-gpu-plugin=2024.0.0=h2e90f83_1 + - libopenvino-ir-frontend=2024.0.0=h3ecfda7_1 + - libopenvino-onnx-frontend=2024.0.0=h757c851_1 + - libopenvino-paddle-frontend=2024.0.0=h757c851_1 + - libopenvino-pytorch-frontend=2024.0.0=h59595ed_1 + - libopenvino-tensorflow-frontend=2024.0.0=hca94c1a_1 + - libopenvino-tensorflow-lite-frontend=2024.0.0=h59595ed_1 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.18=hd590300_0 + - libpng=1.6.43=h2797004_0 + - libpq=16.2=h33b98f1_0 + - libprotobuf=4.25.3=h08a7969_0 + - librosa=0.10.1=pyhd8ed1ab_0 + - librsvg=2.56.3=he3f83f7_1 + - libsndfile=1.2.2=hc60ed4a_1 + - libsodium=1.0.18=h36c2ea0_1 + - libsqlite=3.45.2=h2797004_0 + - libstdcxx-ng=13.2.0=h7e041cc_5 + - libtasn1=4.19.0=h166bdaf_0 + - libtiff=4.6.0=ha9c0a0a_2 + - libunistring=0.9.10=h7f98852_0 + - libuuid=2.38.1=h0b41bf4_0 + - libva=2.21.0=hd590300_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - 
libvpx=1.14.0=h59595ed_0 + - libwebp=1.3.2=h658648e_1 + - libwebp-base=1.3.2=hd590300_0 + - libxcb=1.15=h0b41bf4_0 + - libxcrypt=4.4.36=hd590300_1 + - libxml2=2.12.5=h232c23b_0 + - libzlib=1.2.13=hd590300_5 + - llvm-spirv-15=15.0.0=h0cdce71_1 + - mad=0.15.1b=h9c3ff4c_1 + - markdown-it-py=3.0.0=pyhd8ed1ab_0 + - matplotlib-base=3.8.3=py39he9076e7_0 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 + - mdurl=0.1.2=pyhd8ed1ab_0 + - montreal-forced-aligner=2.2.17=pyhd8ed1ab_0 + - mpg123=1.32.4=h59595ed_0 + - msgpack-python=1.0.7=py39h7633fee_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - ncurses=6.4=h59595ed_2 + - nest-asyncio=1.6.0=pyhd8ed1ab_0 + - nettle=3.9.1=h7ab15ed_0 + - ngram=1.3.14=h924138e_2 + - numba=0.59.0=py39h615d6bd_1 + - numpy=1.26.4=py39h474f0d3_0 + - ocl-icd=2.3.2=hd590300_0 + - openfst=1.8.2=h924138e_2 + - openh264=2.4.1=h59595ed_0 + - openjpeg=2.5.2=h488ebb8_0 + - openssl=3.2.1=hd590300_0 + - p11-kit=0.24.1=hc5aa10d_0 + - packaging=24.0=pyhd8ed1ab_0 + - pandas=2.2.1=py39hddac248_0 + - pango=1.52.1=ha41ecd1_0 + - parso=0.8.3=pyhd8ed1ab_0 + - patsy=0.5.6=pyhd8ed1ab_0 + - pcre2=10.43=hcad00b1_0 + - pexpect=4.9.0=pyhd8ed1ab_0 + - pgvector-python=0.2.5=pyhe093146_0 + - pickleshare=0.7.5=py_1003 + - pillow=10.2.0=py39had0adad_0 + - pip=24.0=pyhd8ed1ab_0 + - pixman=0.43.2=h59595ed_0 + - platformdirs=4.2.0=pyhd8ed1ab_0 + - pocl=5.0=h03a6ac1_2 + - pocl-core=5.0=hdaecddf_2 + - pocl-cpu=5.0=he901f76_2 + - pocl-cpu-minimal=5.0=h5ccd973_2 + - pocl-cuda=5.0=hdaecddf_2 + - pocl-remote=5.0=h5ccd973_2 + - pooch=1.8.1=pyhd8ed1ab_0 + - postgresql=16.2=h7387d8b_0 + - prompt-toolkit=3.0.42=pyha770c72_0 + - prompt_toolkit=3.0.42=hd8ed1ab_0 + - psutil=5.9.8=py39hd1e30aa_0 + - psycopg2=2.9.9=py39h89197e3_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pugixml=1.14=h59595ed_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pygments=2.17.2=pyhd8ed1ab_0 + - pyparsing=3.1.2=pyhd8ed1ab_0 + - pysocks=1.7.1=pyha2e5f31_6 + - pysoundfile=0.12.1=pypyhd8ed1ab_1 + - python=3.9.18=h0755675_1_cpython + - python-tzdata=2024.1=pyhd8ed1ab_0 + - python_abi=3.9=4_cp39 + - pytz=2024.1=pyhd8ed1ab_0 + - pyyaml=6.0.1=py39hd1e30aa_1 + - pyzmq=25.1.2=py39h8c080ef_0 + - readline=8.2=h8228510_1 + - requests=2.31.0=pyhd8ed1ab_0 + - rich=13.7.1=pyhd8ed1ab_0 + - rich-click=1.7.4=pyhd8ed1ab_0 + - scikit-learn=1.2.2=py39hc236052_2 + - scipy=1.12.0=py39h474f0d3_2 + - seaborn=0.13.2=hd8ed1ab_0 + - seaborn-base=0.13.2=pyhd8ed1ab_0 + - setuptools=69.2.0=pyhd8ed1ab_0 + - six=1.16.0=pyh6c4a22f_0 + - snappy=1.1.10=h9fff704_0 + - sox=14.4.2=ha5cc309_1018 + - soxr=0.1.3=h0b41bf4_3 + - soxr-python=0.3.7=py39h44dd56e_0 + - sqlalchemy=2.0.28=py39hd1e30aa_0 + - sqlite=3.45.2=h2c6b66d_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - statsmodels=0.14.1=py39h44dd56e_0 + - svt-av1=1.8.0=h59595ed_0 + - tbb=2021.11.0=h00ab1b0_1 + - threadpoolctl=3.3.0=pyhc1e730c_0 + - tk=8.6.13=noxft_h4845f30_101 + - tornado=6.4=py39hd1e30aa_0 + - tqdm=4.66.2=pyhd8ed1ab_0 + - traitlets=5.14.2=pyhd8ed1ab_0 + - typing-extensions=4.10.0=hd8ed1ab_0 + - typing_extensions=4.10.0=pyha770c72_0 + - tzcode=2024a=h3f72095_0 + - tzdata=2024a=h0c530f3_0 + - unicodedata2=15.1.0=py39hd1e30aa_0 + - urllib3=2.2.1=pyhd8ed1ab_0 + - wcwidth=0.2.13=pyhd8ed1ab_0 + - wheel=0.42.0=pyhd8ed1ab_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xorg-fixesproto=5.0=h7f98852_1002 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.7=h8ee46fc_0 + - xorg-libxau=1.0.11=hd590300_0 + - 
xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxrender=0.9.11=hd590300_0 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.6=h166bdaf_0 + - yaml=0.2.5=h7f98852_2 + - zeromq=4.3.5=h59595ed_1 + - zipp=3.17.0=pyhd8ed1ab_0 + - zlib=1.2.13=hd590300_5 + - zstd=1.5.5=hfc55251_0 + - pip: + - absl-py==2.1.0 + - aiofiles==23.2.1 + - aiohttp==3.9.3 + - aiosignal==1.3.1 + - altair==5.2.0 + - antlr4-python3-runtime==4.9.3 + - anyio==4.3.0 + - async-timeout==4.0.3 + - attrs==23.2.0 + - av==11.0.0 + - babel==2.14.0 + - beautifulsoup4==4.12.3 + - bibtexparser==2.0.0b7 + - bleach==6.1.0 + - blis==0.7.11 + - catalogue==2.0.10 + - clldutils==3.22.2 + - cloudpickle==3.0.0 + - cmake==3.28.3 + - colorlog==6.8.2 + - confection==0.1.4 + - csvw==3.3.0 + - cymem==2.0.8 + - cython==0.29.37 + - datasets==2.16.0 + - defusedxml==0.7.1 + - demucs==4.0.1 + - dill==0.3.6 + - dlinfo==1.2.1 + - docopt==0.6.2 + - dora-search==0.1.12 + - einops==0.7.0 + - encodec==0.1.1 + - exceptiongroup==1.2.0 + - fastapi==0.110.0 + - fastjsonschema==2.19.1 + - ffmpy==0.3.2 + - filelock==3.13.1 + - flashy==0.0.2 + - frozenlist==1.4.1 + - fsspec==2023.10.0 + - gradio==3.50.2 + - gradio-client==0.6.1 + - grpcio==1.62.1 + - h11==0.14.0 + - httpcore==1.0.4 + - httpx==0.27.0 + - huggingface-hub==0.21.4 + - hydra-colorlog==1.2.0 + - hydra-core==1.3.2 + - ipython==8.12.3 + - isodate==0.6.1 + - jinja2==3.1.3 + - jsonschema==4.21.1 + - jsonschema-specifications==2023.12.1 + - julius==0.2.7 + - jupyterlab-pygments==0.3.0 + - lameenc==1.7.0 + - langcodes==3.3.0 + - language-tags==1.2.0 + - lit==18.1.1 + - llvmlite==0.42.0 + - lxml==5.1.0 + - markdown==3.5.2 + - markupsafe==2.1.5 + - mistune==3.0.2 + - mpmath==1.3.0 + - msgpack==1.0.8 + - multidict==6.0.5 + - multiprocess==0.70.14 + - murmurhash==1.0.10 + - nbclient==0.10.0 + - nbconvert==7.16.3 + - nbformat==5.10.3 + - networkx==3.2.1 + - num2words==0.5.13 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-cupti-cu11==11.7.101 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - nvidia-cufft-cu11==10.9.0.58 + - nvidia-curand-cu11==10.2.10.91 + - nvidia-cusolver-cu11==11.4.0.1 + - nvidia-cusparse-cu11==11.7.4.91 + - nvidia-nccl-cu11==2.14.3 + - nvidia-nvtx-cu11==11.7.91 + - omegaconf==2.3.0 + - openunmix==1.2.1 + - orjson==3.9.15 + - pandocfilters==1.5.1 + - pathlib-abc==0.1.1 + - pathy==0.11.0 + - pgvector==0.2.2 + - phonemizer==3.2.1 + - pipreqs==0.5.0 + - praatio==6.2.0 + - preshed==3.0.9 + - protobuf==4.25.3 + - pyarrow==15.0.2 + - pyarrow-hotfix==0.6 + - pydantic==1.10.14 + - pydub==0.25.1 + - pylatexenc==2.10 + - pynini==2.1.6 + - pypinyin==0.48.0 + - python-dateutil==2.9.0.post0 + - python-multipart==0.0.9 + - rdflib==7.0.0 + - referencing==0.33.0 + - regex==2023.12.25 + - responses==0.18.0 + - retrying==1.3.4 + - rfc3986==1.5.0 + - rpds-py==0.18.0 + - safetensors==0.4.2 + - segments==2.2.1 + - semantic-version==2.10.0 + - sentencepiece==0.2.0 + - smart-open==6.4.0 + - sniffio==1.3.1 + - soupsieve==2.5 + - spacy==3.5.2 + - spacy-legacy==3.0.12 + - spacy-loggers==1.0.5 + - srsly==2.4.8 + - starlette==0.36.3 + - submitit==1.5.1 + - sympy==1.12 + - tabulate==0.9.0 + - tensorboard==2.16.2 + - tensorboard-data-server==0.7.2 + - thinc==8.1.12 + - tinycss2==1.2.1 + - tokenizers==0.15.2 + - toolz==0.12.1 + - torch==2.0.1 + - torchaudio==2.0.2 + - torchmetrics==0.11.1 + - transformers==4.38.2 + - 
treetable==0.2.5
+  - triton==2.0.0
+  - typer==0.7.0
+  - uritemplate==4.1.1
+  - uvicorn==0.28.0
+  - wasabi==1.1.2
+  - webencodings==0.5.1
+  - websockets==11.0.3
+  - werkzeug==3.0.1
+  - xformers==0.0.22
+  - xxhash==3.4.1
+  - yarg==0.1.9
+  - yarl==1.9.4
+prefix: /home/pyp/miniconda3/envs/voicecraft
diff --git a/models/voicecraft.py b/models/voicecraft.py
index 4042cae..8d83729 100644
--- a/models/voicecraft.py
+++ b/models/voicecraft.py
@@ -504,7 +504,7 @@ class VoiceCraft(nn.Module):
         ntokens = []
         top10acc = []
         for k, (logit, target) in enumerate(zip(logits, targets)):
-            loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None))
+            loss.append(F.cross_entropy(logit, target, reduction='mean'))
             top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
             ntokens.append(len(logit))
 
@@ -988,6 +988,8 @@ class VoiceCraft(nn.Module):
             for jj in range(1,self.args.n_codebooks):
                 logits_adjust[jj][eog_inference] = -10000
                 logits_adjust[jj][self.args.empty_token] = -10000
+            if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stops too early
+                logits_adjust[0][eog_inference] = -10000
             ##################### silence repetition handling #####################
             if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
                 if logits_adjust[0, prev_token] < 0:
@@ -1237,6 +1239,8 @@ class VoiceCraft(nn.Module):
             for jj in range(1,self.args.n_codebooks):
                 logits_adjust[:,jj,eog_inference] = -10000
                 logits_adjust[:,jj,self.args.empty_token] = -10000
+            if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stops too early
+                logits_adjust[:,:,eog_inference] = -10000
             ##################### silence repetition handling #####################
             for b in range(batch_size):
                 prev_token = prev_tokens[b]
diff --git a/z_scripts/e830M.sh b/z_scripts/e830M.sh
index 5394e83..ff83329 100644
--- a/z_scripts/e830M.sh
+++ b/z_scripts/e830M.sh
@@ -7,9 +7,9 @@ export WORLD_SIZE=4
 dataset=gigaspeech
 mkdir -p ./logs/${dataset}
 
-exp_root="/data/scratch/pyp/exp_pyp/VoiceCraft"
+exp_root="path/to/store/exp_results"
 exp_name=e830M
-dataset_dir="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest/xl"
+dataset_dir="path/to/store_extracted_codes_and_phonemes/xl" # use xs if you only extracted xs in the previous step
 encodec_codes_folder_name="encodec_16khz_4codebooks"
 
 # export CUDA_LAUNCH_BLOCKING=1 # for debugging
 
@@ -51,7 +51,7 @@ torchrun --nnodes=1 --rdzv-backend=c10d --rdzv-endpoint=localhost:41977 --nproc_
 --text_vocab_size 100 \
 --text_pad_token 100 \
 --phn_folder_name "phonemes" \
---manifest_name "manifest_large16khz_lessambi" \
+--manifest_name "manifest" \
 --encodec_folder_name ${encodec_codes_folder_name} \
 --audio_vocab_size 2048 \
 --empty_token 2048 \