Compare commits
3 Commits
867a8c4672
...
f054cdb698
Author | SHA1 | Date |
---|---|---|
Neo Wang | f054cdb698 | |
pyp_l40 | 77d1d5a69c | |
Neo Wang | b81fc56736 |
|
@ -4,6 +4,7 @@ import os, random
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import psutil
|
||||
|
||||
from data.tokenizer import (
|
||||
AudioTokenizer,
|
||||
|
@ -40,7 +41,7 @@ def get_args():
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame):
|
||||
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame, half=False):
|
||||
# phonemize
|
||||
text_tokens = [phn2num[phn] for phn in
|
||||
tokenize_text(
|
||||
|
@ -49,6 +50,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
|
|||
]
|
||||
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
|
||||
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
|
||||
print("finished phonemize")
|
||||
|
||||
# encode audio
|
||||
encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
|
||||
|
@ -56,12 +58,19 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
|
|||
assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
|
||||
logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
|
||||
|
||||
process = psutil.Process()
|
||||
print(f"finished encode; memory usage: {process.memory_info().rss}")
|
||||
|
||||
text_tokens = text_tokens.to(device)
|
||||
if half:
|
||||
text_tokens = text_tokens.half()
|
||||
|
||||
# forward
|
||||
stime = time.time()
|
||||
if decode_config['sample_batch_size'] <= 1:
|
||||
logging.info(f"running inference with batch size 1")
|
||||
concat_frames, gen_frames = model.inference_tts(
|
||||
text_tokens.to(device),
|
||||
text_tokens,
|
||||
text_tokens_lens.to(device),
|
||||
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
|
||||
top_k=decode_config['top_k'],
|
||||
|
@ -74,7 +83,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
|
|||
else:
|
||||
logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.")
|
||||
concat_frames, gen_frames = model.inference_tts_batch(
|
||||
text_tokens.to(device),
|
||||
text_tokens,
|
||||
text_tokens_lens.to(device),
|
||||
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
|
||||
top_k=decode_config['top_k'],
|
||||
|
@ -85,6 +94,9 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
|
|||
batch_size = decode_config['sample_batch_size'],
|
||||
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
|
||||
) # output is [1,K,T]
|
||||
|
||||
print("finished forward pass")
|
||||
|
||||
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
|
||||
|
||||
logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")
|
||||
|
|
|
@ -0,0 +1,171 @@
|
|||
import os
|
||||
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
os.environ["USER"] = "neow" # TODO change this to your username
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from data.tokenizer import (
|
||||
AudioTokenizer,
|
||||
TextTokenizer,
|
||||
)
|
||||
|
||||
import subprocess as sp
|
||||
import os
|
||||
|
||||
def get_gpu_memory():
    """Print and return the free memory of every GPU listed by ``nvidia-smi``.

    Returns:
        list[int]: one value per GPU, in the order and unit that
        ``nvidia-smi --query-gpu=memory.free`` reports them. An empty list is
        returned when ``nvidia-smi`` is not installed or exits with an error
        (e.g. on a CPU-only machine), so that this purely diagnostic helper
        never aborts the surrounding script.
    """
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    try:
        output = sp.check_output(command.split()).decode('ascii')
    except (OSError, sp.CalledProcessError) as err:
        # OSError covers a missing executable; CalledProcessError a non-zero exit.
        print(f"could not query GPU memory: {err}")
        return []
    # The first CSV row is the column header; every later non-blank row starts
    # with the numeric value followed by its unit.
    rows = [line for line in output.splitlines()[1:] if line.strip()]
    memory_free_values = [int(row.split()[0]) for row in rows]
    print(memory_free_values)
    return memory_free_values
|
||||
|
||||
get_gpu_memory()
|
||||
|
||||
if __name__ == "__main__":
    # Demo script: load a pretrained VoiceCraft checkpoint, clone the voice of a
    # reference recording, and write the generated audio to disk while printing
    # GPU memory at a few checkpoints.
    import shutil  # local import: only this script section needs it

    def _download(url, dest):
        """Fetch *url* with wget into *dest*, never leaving a partial file at *dest*."""
        partial = f"{dest}.part"
        sp.run(["wget", "-O", partial, url], check=True)
        os.replace(partial, dest)

    # load model, tokenizer, and other necessary files
    device = "cuda" if torch.cuda.is_available() else "cpu"
    from models import voicecraft

    voicecraft_name = "giga330M.pth"  # the 330M checkpoint
    ckpt_fn = f"./pretrained_models/{voicecraft_name}"
    encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
    # the download target directory must exist before anything is written into it
    os.makedirs("./pretrained_models", exist_ok=True)
    if not os.path.exists(ckpt_fn):
        _download(f"https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}?download=true", ckpt_fn)
    if not os.path.exists(encodec_fn):
        _download("https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th", encodec_fn)

    ckpt = torch.load(ckpt_fn, map_location="cpu")
    model = voicecraft.VoiceCraft(ckpt["config"])
    model.load_state_dict(ckpt["model"])
    model.to(device)
    # model.half()  # left disabled; inference_one_sample is called with half=False below
    model.eval()

    print("loaded model")
    get_gpu_memory()

    phn2num = ckpt['phn2num']

    text_tokenizer = TextTokenizer(backend="espeak")
    audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device)  # will also put the neural codec model on gpu

    # %%

    # Prepare your audio
    # point to the original audio whose speech you want to clone
    # write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file
    orig_audio = "./demo/84_121550_000074_000000.wav"
    orig_transcript = "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,"

    # copy the audio and transcript to the temp folder
    temp_folder = "./demo/temp"
    os.makedirs(temp_folder, exist_ok=True)
    shutil.copy(orig_audio, temp_folder)  # raises if the reference audio is missing
    filename = os.path.splitext(os.path.basename(orig_audio))[0]
    with open(f"{temp_folder}/{filename}.txt", "w") as f:
        f.write(orig_transcript)
    # run MFA to get the alignment
    align_temp = f"{temp_folder}/mfa_alignments"

    # # if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue
    # !source ~/.bashrc && \
    #     conda activate voicecraft && \
    #     mfa align -v --clean -j 1 --output_format csv {temp_folder} \
    #         english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000

    # take a look at demo/temp/mfa_alignment, decide which part of the audio to use as prompt
    # NOTE(review): the forced-alignment file demo/temp/mfa_alignments/84_121550_000074_000000.csv
    # is said to put the end of the word "common" at 3.01 sec, yet the cut-off used here is 7.0 sec;
    # confirm that the prompt audio still matches the beginning of target_transcript.
    # This value should be different for different audio.
    cut_off_sec = 7.0
    target_transcript = "But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well! I love shuffle 512 and janise"
    # NOTE: 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec.
    audio_fn = f"{temp_folder}/{filename}.wav"
    info = torchaudio.info(audio_fn)
    audio_dur = info.num_frames / info.sample_rate

    # explicit check instead of assert, so it is not stripped under `python -O`
    if not cut_off_sec < audio_dur:
        raise ValueError(f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}")
    prompt_end_frame = int(cut_off_sec * info.sample_rate)

    # run the model to get the output
    # hyperparameters for inference
    codec_audio_sr = 16000
    codec_sr = 50
    top_k = 0
    top_p = 0.8
    temperature = 1
    silence_tokens = [1388, 1898, 131]
    kvcache = 0  # NOTE 0 disables the kv-cache (the setting to use if you hit OOM); set to 1 to enable it

    # NOTE adjust the below three arguments if the generation is not as good
    stop_repetition = 3  # NOTE if the model generates long silence, reduce stop_repetition to 3, 2 or even 1
    # NOTE: if there are long silences or unnaturally stretched words, increase sample_batch_size to 5 or higher.
    # The model will then run sample_batch_size examples of the same audio and pick the shortest one.
    # So if the speech rate of the generated audio is too fast, change it to a smaller number.
    sample_batch_size = 1
    seed = 1  # change seed if you are still unhappy with the result

    def seed_everything(seed):
        """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels."""
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    seed_everything(seed)

    decode_config = {
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature,
        'stop_repetition': stop_repetition,
        'kvcache': kvcache,
        "codec_audio_sr": codec_audio_sr,
        "codec_sr": codec_sr,
        "silence_tokens": silence_tokens,
        "sample_batch_size": sample_batch_size,
    }
    from inference_tts_scale import inference_one_sample

    print("before inference")
    get_gpu_memory()

    concated_audio, gen_audio = inference_one_sample(
        model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
        audio_fn, target_transcript, device, decode_config, prompt_end_frame,
        half=False,
    )

    print("after inference")
    get_gpu_memory()

    # save segments for comparison
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
    # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")

    # display the audio
    # from IPython.display import Audio
    #
    # print("concatenate prompt and generated:")
    # display(Audio(concated_audio, rate=codec_audio_sr))
    #
    # print("generated:")
    # display(Audio(gen_audio, rate=codec_audio_sr))

    # save the audio
    # relative to the repository like every other path in this script, instead
    # of an absolute path inside one particular user's home directory
    output_dir = "./demo/generated_tts"
    os.makedirs(output_dir, exist_ok=True)
    seg_save_fn_gen = f"{output_dir}/{filename}_gen_seed{seed}.wav"
    seg_save_fn_concat = f"{output_dir}/{filename}_concat_seed{seed}.wav"

    torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)
    torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)

    print("finished running")
|
||||
|
||||
# if you get error importing T5 in transformers
|
||||
# try
|
||||
# pip uninstall Pillow
|
||||
# pip install Pillow
|
||||
# you might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1); these can be safely ignored
|
||||
|
||||
# %%
|
||||
|
||||
|
|
@ -136,15 +136,15 @@ class VoiceCraft(
|
|||
|
||||
self.text_embedding = TokenEmbedding(
|
||||
dim_model=self.args.d_model,
|
||||
vocab_size=self.n_text_tokens,
|
||||
vocab_size=self.n_text_tokens,
|
||||
dropout=self.args.text_embedding_dropout
|
||||
)
|
||||
|
||||
self.audio_embedding = nn.ModuleList(
|
||||
[
|
||||
TokenEmbedding(
|
||||
dim_model=self.args.audio_embedding_dim,
|
||||
vocab_size=self.n_audio_tokens[k],
|
||||
dim_model=self.args.audio_embedding_dim,
|
||||
vocab_size=self.n_audio_tokens[k],
|
||||
dropout=self.args.audio_embedding_dropout
|
||||
) for k in range(self.args.n_codebooks)
|
||||
]
|
||||
|
@ -177,13 +177,13 @@ class VoiceCraft(
|
|||
num_layers=self.args.num_decoder_layers,
|
||||
norm=LayerNorm(self.args.d_model),
|
||||
)
|
||||
|
||||
|
||||
self.predict_layer = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
self.accuracy_metrics = nn.ModuleList(
|
||||
[MulticlassAccuracy(
|
||||
self.n_audio_tokens[k],
|
||||
|
@ -194,7 +194,7 @@ class VoiceCraft(
|
|||
) for k in range(self.args.n_codebooks)]
|
||||
)
|
||||
|
||||
|
||||
|
||||
def prepare_mask_intervals(self, y_lens):
|
||||
mask_intervals = []
|
||||
non_mask_intervals = []
|
||||
|
@ -230,12 +230,12 @@ class VoiceCraft(
|
|||
temp_mask_end = gap - 1
|
||||
mask_len = random.randint(temp_mask_start, temp_mask_end)
|
||||
ends.append(start + mask_len)
|
||||
|
||||
|
||||
mask_intervals.append([(s,e) for s,e in zip(starts, ends)])
|
||||
non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])])
|
||||
|
||||
return mask_intervals, non_mask_intervals
|
||||
|
||||
|
||||
def rearrange(self, y, non_mask_intervals, mask_intervals):
|
||||
reduced_eog = getattr(self.args, "reduced_eog", 0)
|
||||
rearranged_y = []
|
||||
|
@ -250,7 +250,7 @@ class VoiceCraft(
|
|||
cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment)
|
||||
rearranged_y.append(cur_y)
|
||||
return rearranged_y
|
||||
|
||||
|
||||
def shift(self, rearranged_y):
|
||||
shifted_y = []
|
||||
patterns = []
|
||||
|
@ -260,7 +260,7 @@ class VoiceCraft(
|
|||
shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask
|
||||
patterns.append(cur_patterns)
|
||||
return shifted_y, patterns
|
||||
|
||||
|
||||
def insert_mask(self, shifted_y):
|
||||
inserted_y = []
|
||||
mask_position = []
|
||||
|
@ -286,7 +286,7 @@ class VoiceCraft(
|
|||
inserted_y.append(cur_inserted_y)
|
||||
mask_position.append(cur_mask_position)
|
||||
return inserted_y, mask_position, mask_value
|
||||
|
||||
|
||||
def cat_y(self, inserted_y, mask_position, y_lens):
|
||||
reduced_eog = getattr(self.args, "reduced_eog", 0)
|
||||
cated_y = []
|
||||
|
@ -316,9 +316,9 @@ class VoiceCraft(
|
|||
embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
|
||||
for i in range(len(embedded_y)):
|
||||
if len(mask_position[i]) > 0:
|
||||
embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
|
||||
embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
|
||||
return embedded_y
|
||||
|
||||
|
||||
def prepare_input_target(self, y, y_lens):
|
||||
# rearrange y
|
||||
# assume y shape: [B T K], K is n_codebooks
|
||||
|
@ -355,16 +355,16 @@ class VoiceCraft(
|
|||
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
|
||||
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
|
||||
assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
|
||||
|
||||
|
||||
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
|
||||
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
|
||||
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
|
||||
|
||||
|
||||
|
||||
# embed remember to separately embed the mask tokens
|
||||
embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD
|
||||
assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape
|
||||
|
||||
|
||||
# positional embedding
|
||||
y_input = self.audio_positional_embedding(embedded_y)
|
||||
|
||||
|
@ -381,9 +381,9 @@ class VoiceCraft(
|
|||
non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)]
|
||||
cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals]
|
||||
logits_use.append(cur_logits_use)
|
||||
|
||||
|
||||
return logits_use
|
||||
|
||||
|
||||
def revert_pattern(self, patterns, logits_use):
|
||||
logits_final = []
|
||||
logit_masks = []
|
||||
|
@ -403,9 +403,10 @@ class VoiceCraft(
|
|||
|
||||
return logits_final, logit_masks
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float16)
|
||||
def dec_forward(
|
||||
self,
|
||||
x_input,
|
||||
self,
|
||||
x_input,
|
||||
x_lens,
|
||||
x_attention_mask,
|
||||
x_padding_mask,
|
||||
|
@ -449,7 +450,8 @@ class VoiceCraft(
|
|||
xy_input = torch.cat([x_input, y_input], dim=1)
|
||||
|
||||
if past == None: # do not use kvcache
|
||||
out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
|
||||
out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
|
||||
# out = out.half() # TODO: make this an option => only on if dtype = float16
|
||||
return out[:, x_lens.max():], None
|
||||
else: # use kvcache
|
||||
if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
|
||||
|
@ -469,6 +471,7 @@ class VoiceCraft(
|
|||
else: # used kvcache
|
||||
return out, present
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float16)
|
||||
def forward(self, batch):
|
||||
"""
|
||||
Args:
|
||||
|
@ -500,7 +503,7 @@ class VoiceCraft(
|
|||
x_input = self.text_positional_embedding(x_input)
|
||||
y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens)
|
||||
y_out = self.dec_forward(
|
||||
x_input,
|
||||
x_input,
|
||||
x_lens,
|
||||
x_attention_mask,
|
||||
x_padding_mask,
|
||||
|
@ -511,13 +514,13 @@ class VoiceCraft(
|
|||
)
|
||||
y_out = y_out[0] # no kv-caching during training
|
||||
assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
|
||||
|
||||
|
||||
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
|
||||
# take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern)
|
||||
assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
|
||||
|
||||
logits_use = self.remove_mask(logits, mask_position, new_y_lens)
|
||||
|
||||
|
||||
# revert the pattern shift for each logits section in each sample
|
||||
logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
|
||||
assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"
|
||||
|
@ -540,7 +543,7 @@ class VoiceCraft(
|
|||
loss.append(F.cross_entropy(logit, target, reduction='mean'))
|
||||
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
|
||||
ntokens.append(len(logit))
|
||||
|
||||
|
||||
all_ntokens = sum(ntokens)
|
||||
if self.args.codebook_weight != None:
|
||||
codebook_weight = eval(self.args.codebook_weight)
|
||||
|
@ -557,7 +560,7 @@ class VoiceCraft(
|
|||
"top10acc_by_codebook": top10acc_by_codebook,
|
||||
"effective_ntoken": ntokens,
|
||||
}
|
||||
|
||||
|
||||
def inference(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
@ -626,7 +629,7 @@ class VoiceCraft(
|
|||
non_mask_intervals = [[
|
||||
(ns, ne) for ns, ne in zip(ends, starts)
|
||||
]]
|
||||
|
||||
|
||||
# rearrange y
|
||||
# each section will have EOG added (SOG will be generated by the pattern class)
|
||||
# but mask can be inserted later after we have shifted the input
|
||||
|
@ -658,7 +661,7 @@ class VoiceCraft(
|
|||
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
|
||||
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
|
||||
assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
|
||||
|
||||
|
||||
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
|
||||
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
|
||||
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
|
||||
|
@ -711,10 +714,10 @@ class VoiceCraft(
|
|||
##################### silence repetition handling #####################
|
||||
# prepare the cache placeholder
|
||||
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
||||
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
||||
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
|
||||
# handle multi-span kv-cache
|
||||
new_masked_span = False
|
||||
|
||||
|
||||
def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
|
||||
if n_eog == 0:
|
||||
logits_adjust = logits
|
||||
|
@ -788,7 +791,7 @@ class VoiceCraft(
|
|||
|
||||
while True:
|
||||
y_out, present = self.dec_forward(
|
||||
x_input,
|
||||
x_input,
|
||||
x_lens,
|
||||
x_attention_mask,
|
||||
x_padding_mask,
|
||||
|
@ -868,7 +871,7 @@ class VoiceCraft(
|
|||
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
||||
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
||||
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
||||
|
||||
|
||||
assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}"
|
||||
|
||||
# # combine non_masked_span with generated spans
|
||||
|
@ -899,7 +902,7 @@ class VoiceCraft(
|
|||
|
||||
expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen])
|
||||
assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
||||
|
||||
|
||||
if self.args.special_first:
|
||||
res = res - int(self.args.n_special)
|
||||
|
||||
|
@ -980,7 +983,7 @@ class VoiceCraft(
|
|||
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
||||
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
|
||||
embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
|
||||
|
||||
|
||||
# positional embedding
|
||||
y_input = self.audio_positional_embedding(embedded_y)
|
||||
|
||||
|
@ -1011,7 +1014,7 @@ class VoiceCraft(
|
|||
|
||||
# prepare the cache placeholder
|
||||
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
||||
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
||||
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
|
||||
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
||||
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
||||
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
||||
|
@ -1067,7 +1070,7 @@ class VoiceCraft(
|
|||
return samples, codebook_eog, prev_token, consec_silence_count
|
||||
while True:
|
||||
y_out, present = self.dec_forward(
|
||||
x_input,
|
||||
x_input,
|
||||
x_lens,
|
||||
x_attention_mask,
|
||||
x_padding_mask,
|
||||
|
@ -1091,9 +1094,9 @@ class VoiceCraft(
|
|||
if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
|
||||
for jj in range(self.args.n_codebooks):
|
||||
logits[jj][self.args.eog] = -10000.
|
||||
|
||||
|
||||
samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
|
||||
|
||||
|
||||
cur_num_gen += 1
|
||||
cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
|
||||
|
||||
|
@ -1111,14 +1114,14 @@ class VoiceCraft(
|
|||
break
|
||||
else:
|
||||
assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
|
||||
|
||||
|
||||
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
||||
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
||||
# make attention mask and padding mask
|
||||
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
||||
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
||||
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
|
||||
|
||||
|
||||
assert len(generated) == 1, f"len(generated): {len(generated)}"
|
||||
|
||||
# revert the pattern
|
||||
|
@ -1138,7 +1141,7 @@ class VoiceCraft(
|
|||
|
||||
flatten_gen.append(unshifted_span)
|
||||
assert len(flatten_gen) == 1, len(flatten_gen)
|
||||
|
||||
|
||||
# combine
|
||||
res = [y[0], flatten_gen[0]]
|
||||
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
|
||||
|
@ -1230,7 +1233,7 @@ class VoiceCraft(
|
|||
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
||||
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
|
||||
embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
|
||||
|
||||
|
||||
# positional embedding
|
||||
y_input = self.audio_positional_embedding(embedded_y)
|
||||
|
||||
|
@ -1261,7 +1264,7 @@ class VoiceCraft(
|
|||
|
||||
# prepare the cache placeholder
|
||||
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
||||
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
||||
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
|
||||
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
||||
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
||||
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
||||
|
@ -1344,7 +1347,7 @@ class VoiceCraft(
|
|||
else:
|
||||
assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
|
||||
y_out, present = self.dec_forward(
|
||||
x_input,
|
||||
x_input,
|
||||
x_lens,
|
||||
x_attention_mask,
|
||||
x_padding_mask,
|
||||
|
@ -1370,7 +1373,7 @@ class VoiceCraft(
|
|||
for jj in range(self.args.n_codebooks):
|
||||
logits[:,jj,self.args.eog] = -10000.
|
||||
samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep)
|
||||
|
||||
|
||||
cur_num_gen += 1
|
||||
if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples
|
||||
assert keep == None
|
||||
|
@ -1380,7 +1383,7 @@ class VoiceCraft(
|
|||
assert keep != None
|
||||
cur_generated = cur_generated[keep]
|
||||
cur_generated.append(samples[keep].squeeze(-1))
|
||||
else: # we are generating the rest eogs for the 'keep' sample
|
||||
else: # we are generating the rest eogs for the 'keep' sample
|
||||
cur_generated.append(samples[keep].squeeze(-1))
|
||||
|
||||
# samples.shape is [K,1]
|
||||
|
@ -1397,14 +1400,14 @@ class VoiceCraft(
|
|||
break
|
||||
else:
|
||||
assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
|
||||
|
||||
|
||||
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
||||
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
||||
# make attention mask and padding mask
|
||||
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
|
||||
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
|
||||
y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device)
|
||||
|
||||
|
||||
assert len(generated) == 1, f"len(generated): {len(generated)}"
|
||||
|
||||
# revert the pattern
|
||||
|
@ -1424,7 +1427,7 @@ class VoiceCraft(
|
|||
|
||||
flatten_gen.append(unshifted_span)
|
||||
assert len(flatten_gen) == 1, len(flatten_gen)
|
||||
|
||||
|
||||
# combine
|
||||
res = [y[0], flatten_gen[0]]
|
||||
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
|
||||
|
|
|
@ -78,7 +78,6 @@ class Trainer:
|
|||
|
||||
if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
|
||||
self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
|
||||
self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])
|
||||
|
||||
all_inds = list(range(len(batch['y'])))
|
||||
sum_losses = 0
|
||||
|
|
Loading…
Reference in New Issue