Compare commits

...

2 Commits

Author SHA1 Message Date
Forkoz f26822614b
Merge 6dda1a4f32 into da6d34e26e 2024-04-28 11:28:45 +05:30
Forkoz 6dda1a4f32
Float16 KV Cache in voicecraft.py 2024-04-05 17:52:28 +00:00
1 changed file with 3 additions and 3 deletions

View File

@@ -711,7 +711,7 @@ class VoiceCraft(
##################### silence repetition handling #####################
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# handle multi-span kv-cache
new_masked_span = False
@@ -1011,7 +1011,7 @@ class VoiceCraft(
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@@ -1261,7 +1261,7 @@ class VoiceCraft(
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")