avoid zero size batch caused by grad_accumulation
This commit is contained in:
parent
778db3443d
commit
17061636f2
|
@ -32,6 +32,7 @@
|
||||||
"import torchaudio\n",
|
"import torchaudio\n",
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import random\n",
|
"import random\n",
|
||||||
|
"from argparse import Namespace\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from data.tokenizer import (\n",
|
"from data.tokenizer import (\n",
|
||||||
" AudioTokenizer,\n",
|
" AudioTokenizer,\n",
|
||||||
|
@ -45,7 +46,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# install MFA models and dictionaries if you haven't done so already\n",
|
"# install MFA models and dictionaries if you haven't done so already, already done in the dockerfile or envrionment setup\n",
|
||||||
"!source ~/.bashrc && \\\n",
|
"!source ~/.bashrc && \\\n",
|
||||||
" conda activate voicecraft && \\\n",
|
" conda activate voicecraft && \\\n",
|
||||||
" mfa model download dictionary english_us_arpa && \\\n",
|
" mfa model download dictionary english_us_arpa && \\\n",
|
||||||
|
@ -61,28 +62,38 @@
|
||||||
"# load model, encodec, and phn2num\n",
|
"# load model, encodec, and phn2num\n",
|
||||||
"# # load model, tokenizer, and other necessary files\n",
|
"# # load model, tokenizer, and other necessary files\n",
|
||||||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||||
|
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
|
||||||
|
"\n",
|
||||||
|
"# the old way of loading the model\n",
|
||||||
"from models import voicecraft\n",
|
"from models import voicecraft\n",
|
||||||
"#import models.voicecraft as voicecraft\n",
|
|
||||||
"voicecraft_name=\"gigaHalfLibri330M_TTSEnhanced_max16s.pth\" # or giga330M.pth, giga830M.pth\n",
|
|
||||||
"ckpt_fn =f\"./pretrained_models/{voicecraft_name}\"\n",
|
"ckpt_fn =f\"./pretrained_models/{voicecraft_name}\"\n",
|
||||||
"encodec_fn = \"./pretrained_models/encodec_4cb2048_giga.th\"\n",
|
|
||||||
"if not os.path.exists(ckpt_fn):\n",
|
"if not os.path.exists(ckpt_fn):\n",
|
||||||
" os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\\?download\\=true\")\n",
|
" os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\\?download\\=true\")\n",
|
||||||
" os.system(f\"mv {voicecraft_name}\\?download\\=true ./pretrained_models/{voicecraft_name}\")\n",
|
" os.system(f\"mv {voicecraft_name}\\?download\\=true ./pretrained_models/{voicecraft_name}\")\n",
|
||||||
"if not os.path.exists(encodec_fn):\n",
|
|
||||||
" os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th\")\n",
|
|
||||||
" os.system(f\"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th\")\n",
|
|
||||||
"\n",
|
|
||||||
"ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n",
|
"ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n",
|
||||||
"model = voicecraft.VoiceCraft(ckpt[\"config\"])\n",
|
"model = voicecraft.VoiceCraft(ckpt[\"config\"])\n",
|
||||||
"model.load_state_dict(ckpt[\"model\"])\n",
|
"model.load_state_dict(ckpt[\"model\"])\n",
|
||||||
|
"phn2num = ckpt['phn2num']\n",
|
||||||
|
"config = vars(ckpt['config'])\n",
|
||||||
"model.to(device)\n",
|
"model.to(device)\n",
|
||||||
"model.eval()\n",
|
"model.eval()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"phn2num = ckpt['phn2num']\n",
|
"# # the new way of loading the model, with huggingface, this doesn't work yet\n",
|
||||||
|
"# from models.voicecraft import VoiceCraftHF\n",
|
||||||
|
"# model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
|
||||||
|
"# phn2num = model.args.phn2num # or model.args['phn2num']?\n",
|
||||||
|
"# config = model.config\n",
|
||||||
|
"# model.to(device)\n",
|
||||||
|
"# model.eval()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"text_tokenizer = TextTokenizer(backend=\"espeak\")\n",
|
"\n",
|
||||||
"audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu\n"
|
"encodec_fn = \"./pretrained_models/encodec_4cb2048_giga.th\"\n",
|
||||||
|
"if not os.path.exists(encodec_fn):\n",
|
||||||
|
" os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th\")\n",
|
||||||
|
" os.system(f\"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th\")\n",
|
||||||
|
"audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu\n",
|
||||||
|
"\n",
|
||||||
|
"text_tokenizer = TextTokenizer(backend=\"espeak\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -148,7 +159,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# NOTE adjust the below three arguments if the generation is not as good\n",
|
"# NOTE adjust the below three arguments if the generation is not as good\n",
|
||||||
"stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1\n",
|
"stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1\n",
|
||||||
"sample_batch_size = 2 # for gigaHalfLibri330M_TTSEnhanced_max16s.pth, 1 or 2 should be fine since the model is trained to do TTS, for the other two models, might need a higher number. NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number.\n",
|
"sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number.\n",
|
||||||
"seed = 1 # change seed if you are still unhappy with the result\n",
|
"seed = 1 # change seed if you are still unhappy with the result\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def seed_everything(seed):\n",
|
"def seed_everything(seed):\n",
|
||||||
|
@ -163,7 +174,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n",
|
"decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n",
|
||||||
"from inference_tts_scale import inference_one_sample\n",
|
"from inference_tts_scale import inference_one_sample\n",
|
||||||
"concated_audio, gen_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n",
|
"concated_audio, gen_audio = inference_one_sample(model, Namespace(**config), phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n",
|
||||||
" \n",
|
" \n",
|
||||||
"# save segments for comparison\n",
|
"# save segments for comparison\n",
|
||||||
"concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n",
|
"concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n",
|
||||||
|
@ -190,6 +201,13 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored"
|
"# you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
@ -462,6 +462,8 @@ class VoiceCraft(nn.Module):
|
||||||
before padding.
|
before padding.
|
||||||
"""
|
"""
|
||||||
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
||||||
|
if len(x) == 0:
|
||||||
|
return None
|
||||||
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
||||||
y = y[:, :, :y_lens.max()]
|
y = y[:, :, :y_lens.max()]
|
||||||
assert x.ndim == 2, x.shape
|
assert x.ndim == 2, x.shape
|
||||||
|
|
|
@ -90,6 +90,8 @@ class Trainer:
|
||||||
cur_batch = {key: batch[key][cur_ind] for key in batch}
|
cur_batch = {key: batch[key][cur_ind] for key in batch}
|
||||||
with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32):
|
with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32):
|
||||||
out = self.model(cur_batch)
|
out = self.model(cur_batch)
|
||||||
|
if out == None:
|
||||||
|
continue
|
||||||
|
|
||||||
record_loss = out['loss'].detach().to(self.rank)
|
record_loss = out['loss'].detach().to(self.rank)
|
||||||
top10acc = out['top10acc'].to(self.rank)
|
top10acc = out['top10acc'].to(self.rank)
|
||||||
|
|
Loading…
Reference in New Issue