Float16 KV Cache in voicecraft.py

This commit is contained in:
Forkoz 2024-04-05 17:52:28 +00:00 committed by GitHub
parent 2506954b64
commit 6dda1a4f32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -678,7 +678,7 @@ class VoiceCraft(nn.Module):
##################### silence repetition handling #####################
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# handle multi-span kv-cache
new_masked_span = False
@ -978,7 +978,7 @@ class VoiceCraft(nn.Module):
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@ -1228,7 +1228,7 @@ class VoiceCraft(nn.Module):
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@ -1403,4 +1403,4 @@ class VoiceCraft(nn.Module):
res = res - int(self.args.n_special)
flatten_gen = flatten_gen - int(self.args.n_special)
return res, flatten_gen[0].unsqueeze(0)
return res, flatten_gen[0].unsqueeze(0)