diff --git a/models/voicecraft.py b/models/voicecraft.py index 9bb3393..8811160 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -1416,7 +1416,13 @@ class VoiceCraft(nn.Module): return res, flatten_gen[0].unsqueeze(0) -class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]): +class VoiceCraftHF( + VoiceCraft, + PyTorchModelHubMixin, + repo_url="https://github.com/jasonppy/VoiceCraft", + tags=["Text-to-Speech"], + library_name="voicecraft" + ): def __init__(self, config: dict): args = Namespace(**config) - super().__init__(args) \ No newline at end of file + super().__init__(args)