diff --git a/models/voicecraft.py b/models/voicecraft.py index 8ea85ad..ee8e5a7 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -1414,7 +1414,9 @@ class VoiceCraft(nn.Module): return res, flatten_gen[0].unsqueeze(0) -class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin): +class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, + repo_url="https://github.com/jasonppy/VoiceCraft", + tags=["Text-to-Speech, VoiceCraft"]): def __init__(self, config: dict): args = Namespace(**config) super().__init__(args) \ No newline at end of file