Deduplicate VoiceCraftHF <> VoiceCraft
This commit is contained in:
parent
943211d751
commit
e550f61409
|
@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
|
|||
transcribe_model = WhisperxModel(whisper_model_name, align_model)
|
||||
|
||||
voicecraft_name = f"{voicecraft_model_name}.pth"
|
||||
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
|
||||
model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
|
||||
phn2num = model.args.phn2num
|
||||
config = model.args
|
||||
model.to(device)
|
||||
|
|
|
@ -3,6 +3,7 @@ import random
|
|||
import numpy as np
|
||||
import logging
|
||||
import argparse, copy
|
||||
from typing import Dict, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
|||
|
||||
|
||||
|
||||
class VoiceCraft(nn.Module):
|
||||
def __init__(self, args):
|
||||
class VoiceCraft(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="voicecraft",
|
||||
repo_url="https://github.com/jasonppy/VoiceCraft",
|
||||
tags=["text-to-speech"],
|
||||
):
|
||||
def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
|
||||
# If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
|
||||
# Won't affect instance initialization
|
||||
if args is not None:
|
||||
if config is not None:
|
||||
raise ValueError("Cannot provide both `args` and `config`.")
|
||||
config = vars(args)
|
||||
return super().__new__(cls, args=args, config=config, **kwargs)
|
||||
|
||||
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
|
||||
super().__init__()
|
||||
|
||||
# If loaded from HF Hub => convert config.json to Namespace args before initializing
|
||||
if args is None:
|
||||
if config is None:
|
||||
raise ValueError("Either `args` or `config` must be provided.")
|
||||
args = Namespace(**config)
|
||||
|
||||
self.args = copy.copy(args)
|
||||
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
|
||||
if not getattr(self.args, "special_first", False):
|
||||
|
@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
|
|||
if self.args.eos > 0:
|
||||
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
|
||||
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
|
||||
if type(self.args.audio_vocab_size) == str:
|
||||
if isinstance(self.args.audio_vocab_size, str):
|
||||
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
|
||||
|
||||
self.n_text_tokens = self.args.text_vocab_size + 1
|
||||
|
@ -1414,15 +1437,3 @@ class VoiceCraft(nn.Module):
|
|||
flatten_gen = flatten_gen - int(self.args.n_special)
|
||||
|
||||
return res, flatten_gen[0].unsqueeze(0)
|
||||
|
||||
|
||||
class VoiceCraftHF(
|
||||
VoiceCraft,
|
||||
PyTorchModelHubMixin,
|
||||
repo_url="https://github.com/jasonppy/VoiceCraft",
|
||||
tags=["text-to-speech"],
|
||||
library_name="voicecraft"
|
||||
):
|
||||
def __init__(self, config: dict):
|
||||
args = Namespace(**config)
|
||||
super().__init__(args)
|
||||
|
|
Loading…
Reference in New Issue