Deduplicate VoiceCraftHF <> VoiceCraft

This commit is contained in:
Wauplin 2024-04-16 10:24:05 +02:00
parent 943211d751
commit e550f61409
No known key found for this signature in database
GPG Key ID: 9838FE02BECE1A02
2 changed files with 27 additions and 16 deletions

View File

@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
transcribe_model = WhisperxModel(whisper_model_name, align_model) transcribe_model = WhisperxModel(whisper_model_name, align_model)
voicecraft_name = f"{voicecraft_model_name}.pth" voicecraft_name = f"{voicecraft_model_name}.pth"
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
phn2num = model.args.phn2num phn2num = model.args.phn2num
config = model.args config = model.args
model.to(device) model.to(device)

View File

@ -3,6 +3,7 @@ import random
import numpy as np import numpy as np
import logging import logging
import argparse, copy import argparse, copy
from typing import Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module): class VoiceCraft(
def __init__(self, args): nn.Module,
PyTorchModelHubMixin,
library_name="voicecraft",
repo_url="https://github.com/jasonppy/VoiceCraft",
tags=["text-to-speech"],
):
def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
# If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
# Won't affect instance initialization
if args is not None:
if config is not None:
raise ValueError("Cannot provide both `args` and `config`.")
config = vars(args)
return super().__new__(cls, args=args, config=config, **kwargs)
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
super().__init__() super().__init__()
# If loaded from HF Hub => convert config.json to Namespace args before initializing
if args is None:
if config is None:
raise ValueError("Either `args` or `config` must be provided.")
args = Namespace(**config)
self.args = copy.copy(args) self.args = copy.copy(args)
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks) self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
if not getattr(self.args, "special_first", False): if not getattr(self.args, "special_first", False):
@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
if self.args.eos > 0: if self.args.eos > 0:
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1] self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
if type(self.args.audio_vocab_size) == str: if isinstance(self.args.audio_vocab_size, str):
self.args.audio_vocab_size = eval(self.args.audio_vocab_size) self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
self.n_text_tokens = self.args.text_vocab_size + 1 self.n_text_tokens = self.args.text_vocab_size + 1
@ -1414,15 +1437,3 @@ class VoiceCraft(nn.Module):
flatten_gen = flatten_gen - int(self.args.n_special) flatten_gen = flatten_gen - int(self.args.n_special)
return res, flatten_gen[0].unsqueeze(0) return res, flatten_gen[0].unsqueeze(0)
class VoiceCraftHF(
VoiceCraft,
PyTorchModelHubMixin,
repo_url="https://github.com/jasonppy/VoiceCraft",
tags=["text-to-speech"],
library_name="voicecraft"
):
def __init__(self, config: dict):
args = Namespace(**config)
super().__init__(args)