diff --git a/models/voicecraft.py b/models/voicecraft.py index f090c66..8f87264 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -85,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): -class VoiceCraft(nn.Module, PyTorchModelHubMixin): +class VoiceCraft(nn.Module): def __init__(self, args): super().__init__() self.args = copy.copy(args) @@ -1410,4 +1410,9 @@ class VoiceCraft(nn.Module, PyTorchModelHubMixin): res = res - int(self.args.n_special) flatten_gen = flatten_gen - int(self.args.n_special) - return res, flatten_gen[0].unsqueeze(0) \ No newline at end of file + return res, flatten_gen[0].unsqueeze(0) + + +class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin): + def __init__(self, config: dict): + super().__init__(config) \ No newline at end of file