Add class

This commit is contained in:
Niels 2024-04-07 20:17:52 +02:00
parent a11e1b8f3c
commit 92b283c741
1 changed files with 7 additions and 2 deletions

View File

@ -85,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module, PyTorchModelHubMixin): class VoiceCraft(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.args = copy.copy(args) self.args = copy.copy(args)
@ -1410,4 +1410,9 @@ class VoiceCraft(nn.Module, PyTorchModelHubMixin):
res = res - int(self.args.n_special) res = res - int(self.args.n_special)
flatten_gen = flatten_gen - int(self.args.n_special) flatten_gen = flatten_gen - int(self.args.n_special)
return res, flatten_gen[0].unsqueeze(0) return res, flatten_gen[0].unsqueeze(0)
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin):
def __init__(self, config: dict):
super().__init__(config)