diff --git a/gradio_app.py b/gradio_app.py
index 41f64a5..8d2dae6 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
         transcribe_model = WhisperxModel(whisper_model_name, align_model)
 
     voicecraft_name = f"{voicecraft_model_name}.pth"
-    model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
+    model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
     phn2num = model.args.phn2num
     config = model.args
     model.to(device)
diff --git a/models/voicecraft.py b/models/voicecraft.py
index a68a38e..508e55f 100644
--- a/models/voicecraft.py
+++ b/models/voicecraft.py
@@ -3,6 +3,7 @@ import random
 import numpy as np
 import logging
 import argparse, copy
+from typing import Dict, Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
 
 
 
-class VoiceCraft(nn.Module):
-    def __init__(self, args):
+class VoiceCraft(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="voicecraft",
+    repo_url="https://github.com/jasonppy/VoiceCraft",
+    tags=["text-to-speech"],
+    ):
+    def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
+        # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
+        # Won't affect instance initialization
+        if args is not None:
+            if config is not None:
+                raise ValueError("Cannot provide both `args` and `config`.")
+            config = vars(args)
+        return super().__new__(cls, args=args, config=config, **kwargs)
+
+    def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
         super().__init__()
+
+        # If loaded from HF Hub => convert config.json to Namespace args before initializing
+        if args is None:
+            if config is None:
+                raise ValueError("Either `args` or `config` must be provided.")
+            args = Namespace(**config)
+
         self.args = copy.copy(args)
         self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
         if not getattr(self.args, "special_first", False):
@@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
         if self.args.eos > 0:
             assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
             self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
-        if type(self.args.audio_vocab_size) == str:
+        if isinstance(self.args.audio_vocab_size, str):
             self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
 
         self.n_text_tokens = self.args.text_vocab_size + 1
@@ -1414,15 +1437,3 @@ class VoiceCraft(nn.Module):
             flatten_gen = flatten_gen - int(self.args.n_special)
 
         return res, flatten_gen[0].unsqueeze(0)
-
-
-class VoiceCraftHF(
-    VoiceCraft,
-    PyTorchModelHubMixin,
-    repo_url="https://github.com/jasonppy/VoiceCraft",
-    tags=["text-to-speech"],
-    library_name="voicecraft"
-    ):
-    def __init__(self, config: dict):
-        args = Namespace(**config)
-        super().__init__(args)
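
Usage sketch (not part of the diff): with PyTorchModelHubMixin mixed directly into VoiceCraft, the separate VoiceCraftHF wrapper goes away and both the Hub-loading path and the original argparse-based construction go through one class. The snippet below is a minimal sketch assuming the repo layout from this PR (models/voicecraft.py importable as models.voicecraft); the Hub repo id is an illustrative placeholder, not a guaranteed checkpoint name.

# Usage sketch, assuming the layout in this PR; the Hub repo id below is an
# illustrative placeholder.
import torch

from models import voicecraft

# Loading from the Hugging Face Hub now goes through the base class directly:
# from_pretrained() reads config.json, __init__ rebuilds the Namespace from it,
# and the mixin loads the weights onto the instance.
model = voicecraft.VoiceCraft.from_pretrained("pyp1/VoiceCraft_giga830M")  # placeholder repo id
model.to("cuda" if torch.cuda.is_available() else "cpu")
phn2num = model.args.phn2num  # the serialized config round-trips back into model.args

# Constructing from argparse args still works; __new__ mirrors the Namespace into a
# dict config so the mixin can serialize it later, e.g.:
#   model = voicecraft.VoiceCraft(args)              # args: argparse.Namespace
#   model.save_pretrained("./voicecraft-checkpoint") # writes weights + config.json

save_pretrained() and push_to_hub() are inherited from PyTorchModelHubMixin; because __new__ copies vars(args) into config, the full training Namespace is what ends up serialized as config.json.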