commit c185a2928f by Neo Wang, 2024-04-25 22:33:31 -05:00 (committed via GitHub)
3 changed files with 238 additions and 52 deletions

inference_tts_scale.py

@ -4,6 +4,7 @@ import os, random
import numpy as np
import torch
import torchaudio
import psutil
from data.tokenizer import (
AudioTokenizer,
@ -40,7 +41,7 @@ def get_args():
@torch.no_grad()
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame):
def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame, half=False):
# phonemize
text_tokens = [phn2num[phn] for phn in
tokenize_text(
@ -49,6 +50,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
]
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
print("finished phonemize")
# encode audio
encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
@ -56,12 +58,19 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
process = psutil.Process()
print(f"finished encode; memory usage: {process.memory_info().rss}")
text_tokens = text_tokens.to(device)
if half:
text_tokens = text_tokens.half()
# forward
stime = time.time()
if decode_config['sample_batch_size'] <= 1:
logging.info(f"running inference with batch size 1")
concat_frames, gen_frames = model.inference_tts(
text_tokens.to(device),
text_tokens,
text_tokens_lens.to(device),
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
top_k=decode_config['top_k'],
@ -74,7 +83,7 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
else:
logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.")
concat_frames, gen_frames = model.inference_tts_batch(
text_tokens.to(device),
text_tokens,
text_tokens_lens.to(device),
original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
top_k=decode_config['top_k'],
@ -85,6 +94,9 @@ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_token
batch_size = decode_config['sample_batch_size'],
silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
) # output is [1,K,T]
print("finished forward pass")
logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")

memtest.py (new file, 171 lines)

@ -0,0 +1,171 @@
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["USER"] = "neow" # TODO change this to your username
import torch
import torchaudio
import numpy as np
import random
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
import subprocess as sp
def get_gpu_memory():
# query the free GPU memory (in MiB) reported by nvidia-smi for each visible device
command = "nvidia-smi --query-gpu=memory.free --format=csv"
memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]  # drop the trailing blank line and the csv header
memory_free_values = [int(x.split()[0]) for x in memory_free_info]
print(memory_free_values)
get_gpu_memory()
if __name__ == "__main__":
# load model, encodec, phn2num, tokenizers, and other necessary files
device = "cuda" if torch.cuda.is_available() else "cpu"
from models import voicecraft
# import models.voicecraft as voicecraft
voicecraft_name = "giga330M.pth" # or giga830M.pth
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
if not os.path.exists(encodec_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
ckpt = torch.load(ckpt_fn, map_location="cpu")
model = voicecraft.VoiceCraft(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.to(device)
# model.half()
model.eval()
print("loaded model")
get_gpu_memory()
phn2num = ckpt['phn2num']
text_tokenizer = TextTokenizer(backend="espeak")
audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu
# %%
# Prepare your audio
# point to the original audio whose speech you want to clone
# write down the transcript for the file, or run whisper to get the transcript (and you can modify it if it's not accurate), save it as a .txt file
orig_audio = "./demo/84_121550_000074_000000.wav"
orig_transcript = "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,"
# move the audio and transcript to temp folder
temp_folder = "./demo/temp"
os.makedirs(temp_folder, exist_ok=True)
os.system(f"cp {orig_audio} {temp_folder}")
filename = os.path.splitext(orig_audio.split("/")[-1])[0]
with open(f"{temp_folder}/{filename}.txt", "w") as f:
f.write(orig_transcript)
# run MFA to get the alignment
align_temp = f"{temp_folder}/mfa_alignments"
# # if the above fails, it could be because the audio is too hard for the alignment model, increasing the beam size usually solves the issue
# !source ~/.bashrc && \
# conda activate voicecraft && \
# mfa align -v --clean -j 1 --output_format csv {temp_folder} \
# english_us_arpa english_us_arpa {align_temp} --beam 1000 --retry_beam 2000
# take a look at demo/temp/mfa_alignments, decide which part of the audio to use as the prompt
cut_off_sec = 7.0 # NOTE: according to the forced-alignment file demo/temp/mfa_alignments/84_121550_000074_000000.csv, the word "common" stops at 3.01 sec; this will be different for different audio
target_transcript = "But when I had approached so near to them The common I cannot believe that the same model can also do text to speech synthesis as well! I love shuffle 512 and janise"
# NOTE: 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec.
audio_fn = f"{temp_folder}/{filename}.wav"
info = torchaudio.info(audio_fn)
audio_dur = info.num_frames / info.sample_rate
assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
prompt_end_frame = int(cut_off_sec * info.sample_rate)
# run the model to get the output
# hyperparameters for inference
codec_audio_sr = 16000
codec_sr = 50
top_k = 0
top_p = 0.8
temperature = 1
silence_tokens = [1388, 1898, 131]
kvcache = 0 # NOTE kv caching (1) speeds up inference but uses more GPU memory; if you hit OOM, keep this at 0, or try the 330M model
# NOTE adjust the below three arguments if the generation is not as good
stop_repetition = 3 # NOTE if the model generates long silences, reduce stop_repetition to 3, 2, or even 1
sample_batch_size = 1 # NOTE: if there are long silences or unnaturally stretched words, increase sample_batch_size to 5 or higher. The model will then generate sample_batch_size samples for the same input and return the shortest one. If the speech rate of the generated audio is too fast, change it back to a smaller number.
seed = 1 # change seed if you are still unhappy with the result
def seed_everything(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
seed_everything(seed)
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition,
'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
"silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
from inference_tts_scale import inference_one_sample
print("before inference")
get_gpu_memory()
concated_audio, gen_audio = inference_one_sample(model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
audio_fn, target_transcript, device, decode_config, prompt_end_frame, False)
print("after inference")
get_gpu_memory()
# save segments for comparison
concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
# logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")
# display the audio
# from IPython.display import Audio
#
# print("concatenate prompt and generated:")
# display(Audio(concated_audio, rate=codec_audio_sr))
#
# print("generated:")
# display(Audio(gen_audio, rate=codec_audio_sr))
# # save the audio
# # output_dir
output_dir = "/home/pyp/VoiceCraft/demo/generated_tts"
os.makedirs(output_dir, exist_ok=True)
seg_save_fn_gen = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_gen_seed{seed}.wav"
seg_save_fn_concat = f"{output_dir}/{os.path.basename(audio_fn)[:-4]}_concat_seed{seed}.wav"
torchaudio.save(seg_save_fn_gen, gen_audio, codec_audio_sr)
torchaudio.save(seg_save_fn_concat, concated_audio, codec_audio_sr)
print("finished running")
# if you get an error importing T5 from transformers, try:
# pip uninstall Pillow
# pip install Pillow
# you might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1); this can be safely ignored
# %%
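memtest.py above tracks memory by polling nvidia-smi for free device memory before and after each stage; if the quantity of interest is the peak GPU usage of the inference call itself, PyTorch's allocator counters can be read around the call instead. A sketch under the assumption that the model, tokenizers, and decode_config from memtest.py are already in scope; note that max_memory_allocated only counts tensors managed by PyTorch's caching allocator:

import torch
from inference_tts_scale import inference_one_sample

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()  # clear the allocator's high-water mark
    concated_audio, gen_audio = inference_one_sample(
        model, ckpt["config"], phn2num, text_tokenizer, audio_tokenizer,
        audio_fn, target_transcript, device, decode_config, prompt_end_frame, False)
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    free_b, total_b = torch.cuda.mem_get_info()  # free/total device memory in bytes
    print(f"peak allocated during inference: {peak_gib:.2f} GiB; "
          f"free now: {free_b / 2**30:.2f} GiB of {total_b / 2**30:.2f} GiB")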

models/voicecraft.py

@ -136,15 +136,15 @@ class VoiceCraft(
self.text_embedding = TokenEmbedding(
dim_model=self.args.d_model,
vocab_size=self.n_text_tokens,
dropout=self.args.text_embedding_dropout
)
self.audio_embedding = nn.ModuleList(
[
TokenEmbedding(
dim_model=self.args.audio_embedding_dim,
vocab_size=self.n_audio_tokens[k],
dropout=self.args.audio_embedding_dropout
) for k in range(self.args.n_codebooks)
]
@ -177,13 +177,13 @@ class VoiceCraft(
num_layers=self.args.num_decoder_layers,
norm=LayerNorm(self.args.d_model),
)
self.predict_layer = nn.ModuleList(
[
nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
]
)
self.accuracy_metrics = nn.ModuleList(
[MulticlassAccuracy(
self.n_audio_tokens[k],
@ -194,7 +194,7 @@ class VoiceCraft(
) for k in range(self.args.n_codebooks)]
)
def prepare_mask_intervals(self, y_lens):
mask_intervals = []
non_mask_intervals = []
@ -230,12 +230,12 @@ class VoiceCraft(
temp_mask_end = gap - 1
mask_len = random.randint(temp_mask_start, temp_mask_end)
ends.append(start + mask_len)
mask_intervals.append([(s,e) for s,e in zip(starts, ends)])
non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])])
return mask_intervals, non_mask_intervals
def rearrange(self, y, non_mask_intervals, mask_intervals):
reduced_eog = getattr(self.args, "reduced_eog", 0)
rearranged_y = []
@ -250,7 +250,7 @@ class VoiceCraft(
cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment)
rearranged_y.append(cur_y)
return rearranged_y
def shift(self, rearranged_y):
shifted_y = []
patterns = []
@ -260,7 +260,7 @@ class VoiceCraft(
shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask
patterns.append(cur_patterns)
return shifted_y, patterns
def insert_mask(self, shifted_y):
inserted_y = []
mask_position = []
@ -286,7 +286,7 @@ class VoiceCraft(
inserted_y.append(cur_inserted_y)
mask_position.append(cur_mask_position)
return inserted_y, mask_position, mask_value
def cat_y(self, inserted_y, mask_position, y_lens):
reduced_eog = getattr(self.args, "reduced_eog", 0)
cated_y = []
@ -316,9 +316,9 @@ class VoiceCraft(
embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
for i in range(len(embedded_y)):
if len(mask_position[i]) > 0:
embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
return embedded_y
def prepare_input_target(self, y, y_lens):
# rearrange y
# assume y shape: [B T K], K is n_codebooks
@ -355,16 +355,16 @@ class VoiceCraft(
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
# embed remember to separately embed the mask tokens
embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD
assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape
# positional embedding
y_input = self.audio_positional_embedding(embedded_y)
@ -381,9 +381,9 @@ class VoiceCraft(
non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)]
cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals]
logits_use.append(cur_logits_use)
return logits_use
def revert_pattern(self, patterns, logits_use):
logits_final = []
logit_masks = []
@ -403,9 +403,10 @@ class VoiceCraft(
return logits_final, logit_masks
@torch.autocast(device_type="cuda", dtype=torch.float16)
def dec_forward(
self,
x_input,
x_lens,
x_attention_mask,
x_padding_mask,
@ -449,7 +450,8 @@ class VoiceCraft(
xy_input = torch.cat([x_input, y_input], dim=1)
if past == None: # do not use kvcache
out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
# out = out.half() # TODO: make this an option => only on if dtype = float16
return out[:, x_lens.max():], None
else: # use kvcache
if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
@ -469,6 +471,7 @@ class VoiceCraft(
else: # used kvcache
return out, present
@torch.autocast(device_type="cuda", dtype=torch.float16)
def forward(self, batch):
"""
Args:
@ -500,7 +503,7 @@ class VoiceCraft(
x_input = self.text_positional_embedding(x_input)
y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens)
y_out = self.dec_forward(
x_input,
x_lens,
x_attention_mask,
x_padding_mask,
@ -511,13 +514,13 @@ class VoiceCraft(
)
y_out = y_out[0] # no kv-caching during training
assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
# take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern)
assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
logits_use = self.remove_mask(logits, mask_position, new_y_lens)
# revert the pattern shift for each logits section in each sample
logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"
@ -540,7 +543,7 @@ class VoiceCraft(
loss.append(F.cross_entropy(logit, target, reduction='mean'))
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
ntokens.append(len(logit))
all_ntokens = sum(ntokens)
if self.args.codebook_weight != None:
codebook_weight = eval(self.args.codebook_weight)
@ -557,7 +560,7 @@ class VoiceCraft(
"top10acc_by_codebook": top10acc_by_codebook,
"effective_ntoken": ntokens,
}
def inference(
self,
x: torch.Tensor,
@ -626,7 +629,7 @@ class VoiceCraft(
non_mask_intervals = [[
(ns, ne) for ns, ne in zip(ends, starts)
]]
# rearrange y
# will add have EOG in each section (SOG will be generated by the pattern class)
# but mask can be inserted later after we have shifted the input
@ -658,7 +661,7 @@ class VoiceCraft(
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
@ -711,10 +714,10 @@ class VoiceCraft(
##################### silence repetition handling #####################
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# handle multi-span kv-cache
new_masked_span = False
def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
if n_eog == 0:
logits_adjust = logits
@ -788,7 +791,7 @@ class VoiceCraft(
while True:
y_out, present = self.dec_forward(
x_input,
x_lens,
x_attention_mask,
x_padding_mask,
@ -868,7 +871,7 @@ class VoiceCraft(
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}"
# # combine non_masked_span with generated spans
@ -899,7 +902,7 @@ class VoiceCraft(
expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen])
assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"
if self.args.special_first:
res = res - int(self.args.n_special)
@ -980,7 +983,7 @@ class VoiceCraft(
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
# positional embedding
y_input = self.audio_positional_embedding(embedded_y)
@ -1011,7 +1014,7 @@ class VoiceCraft(
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@ -1067,7 +1070,7 @@ class VoiceCraft(
return samples, codebook_eog, prev_token, consec_silence_count
while True:
y_out, present = self.dec_forward(
x_input,
x_lens,
x_attention_mask,
x_padding_mask,
@ -1091,9 +1094,9 @@ class VoiceCraft(
if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
for jj in range(self.args.n_codebooks):
logits[jj][self.args.eog] = -10000.
samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
cur_num_gen += 1
cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
@ -1111,14 +1114,14 @@ class VoiceCraft(
break
else:
assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
# make attention mask and padding mask
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
assert len(generated) == 1, f"len(generated): {len(generated)}"
# revert the pattern
@ -1138,7 +1141,7 @@ class VoiceCraft(
flatten_gen.append(unshifted_span)
assert len(flatten_gen) == 1, len(flatten_gen)
# combine
res = [y[0], flatten_gen[0]]
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
@ -1230,7 +1233,7 @@ class VoiceCraft(
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
# positional embedding
y_input = self.audio_positional_embedding(embedded_y)
@ -1261,7 +1264,7 @@ class VoiceCraft(
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@ -1344,7 +1347,7 @@ class VoiceCraft(
else:
assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
y_out, present = self.dec_forward(
x_input,
x_lens,
x_attention_mask,
x_padding_mask,
@ -1370,7 +1373,7 @@ class VoiceCraft(
for jj in range(self.args.n_codebooks):
logits[:,jj,self.args.eog] = -10000.
samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep)
cur_num_gen += 1
if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples
assert keep == None
@ -1380,7 +1383,7 @@ class VoiceCraft(
assert keep != None
cur_generated = cur_generated[keep]
cur_generated.append(samples[keep].squeeze(-1))
else: # we are generating the rest eogs for the 'keep' sample
cur_generated.append(samples[keep].squeeze(-1))
# samples.shape is [K,1]
@ -1397,14 +1400,14 @@ class VoiceCraft(
break
else:
assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
# make attention mask and padding mask
y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device)
assert len(generated) == 1, f"len(generated): {len(generated)}"
# revert the pattern
@ -1424,7 +1427,7 @@ class VoiceCraft(
flatten_gen.append(unshifted_span)
assert len(flatten_gen) == 1, len(flatten_gen)
# combine
res = [y[0], flatten_gen[0]]
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
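The model-side changes in this file follow a common mixed-precision pattern: the decoder forward passes are decorated with torch.autocast(device_type="cuda", dtype=torch.float16), and the kv-cache placeholder is now allocated in fp16 so cached keys and values match the autocast compute dtype. A self-contained toy sketch of that pattern (TinyDecoder is a stand-in, not the VoiceCraft decoder):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    # toy stand-in used only to illustrate the decorator + fp16 cache pattern
    def __init__(self, d_model: int = 64):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    @torch.autocast(device_type="cuda", dtype=torch.float16)
    def dec_forward(self, x: torch.Tensor) -> torch.Tensor:
        # inside the decorated region, matmul-heavy ops run in fp16 on CUDA;
        # the fp32 weights are cast on the fly, so the module itself can stay fp32
        return self.proj(x)

if torch.cuda.is_available():
    model = TinyDecoder().cuda()
    x = torch.randn(1, 10, 64, device="cuda")
    print(model.dec_forward(x).dtype)  # torch.float16 under autocast
    # mirroring the commit: a cache placeholder allocated directly in fp16,
    # so cached keys/values share the compute dtype used under autocast
    past = torch.ones([4, 2, 1], device="cuda", dtype=torch.float16)
    print(past.dtype)  # torch.float16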