|
|
|
@ -711,7 +711,7 @@ class VoiceCraft(
|
|
|
|
|
##################### silence repetition handling #####################
|
|
|
|
|
# prepare the cache placeholder
|
|
|
|
|
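# presumably the full kv-cache layout: n_layers, 2, bsz, num_heads, src_len, head_dim
# NOTE: the placeholder created below is only 3-D: [num_decoder_layers, 2, x.shape[0]]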
|
|
|
|
|
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
|
|
|
|
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
|
|
|
|
|
# handle multi-span kv-cache
|
|
|
|
|
new_masked_span = False
|
|
|
|
|
|
|
|
|
@ -1011,7 +1011,7 @@ class VoiceCraft(
|
|
|
|
|
|
|
|
|
|
# prepare the cache placeholder
|
|
|
|
|
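# presumably the full kv-cache layout: n_layers, 2, bsz, num_heads, src_len, head_dim
# NOTE: the placeholder created below is only 3-D: [num_decoder_layers, 2, x.shape[0]]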
|
|
|
|
|
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
|
|
|
|
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
|
|
|
|
|
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
|
|
|
|
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
|
|
|
|
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
|
|
|
@ -1261,7 +1261,7 @@ class VoiceCraft(
|
|
|
|
|
|
|
|
|
|
# prepare the cache placeholder
|
|
|
|
|
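# presumably the full kv-cache layout: n_layers, 2, bsz, num_heads, src_len, head_dim
# NOTE: the placeholder created below is only 3-D: [num_decoder_layers, 2, x.shape[0]]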
|
|
|
|
|
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
|
|
|
|
|
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
|
|
|
|
|
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
|
|
|
|
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
|
|
|
|
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
|
|
|
|