Merge branch 'jasonppy:master' into master

This commit is contained in:
Chenxi 2024-04-17 16:27:36 +01:00 committed by GitHub
commit 6e5382584c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 14 deletions

View File

@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
transcribe_model = WhisperxModel(whisper_model_name, align_model)
voicecraft_name = f"{voicecraft_model_name}.pth"
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
phn2num = model.args.phn2num
config = model.args
model.to(device)

View File

@ -203,8 +203,8 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n",
"# the new way of loading the model, with huggingface, recommended\n",
"from models.voicecraft import VoiceCraftHF\n",
"model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"from models import voicecraft\n",
"model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n",
"config = vars(model.args)\n",
"model.to(device)\n",

View File

@ -74,8 +74,8 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n",
"# the new way of loading the model, with huggingface, recommended\n",
"from models.voicecraft import VoiceCraftHF\n",
"model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"from models import voicecraft\n",
"model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n",
"config = vars(model.args)\n",
"model.to(device)\n",

View File

@ -3,6 +3,7 @@ import random
import numpy as np
import logging
import argparse, copy
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module):
def __init__(self, args):
class VoiceCraft(
nn.Module,
PyTorchModelHubMixin,
library_name="voicecraft",
repo_url="https://github.com/jasonppy/VoiceCraft",
tags=["text-to-speech"],
):
def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
# If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
# Won't affect instance initialization
if args is not None:
if config is not None:
raise ValueError("Cannot provide both `args` and `config`.")
config = vars(args)
return super().__new__(cls, args=args, config=config, **kwargs)
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
super().__init__()
# If loaded from HF Hub => convert config.json to Namespace args before initializing
if args is None:
if config is None:
raise ValueError("Either `args` or `config` must be provided.")
args = Namespace(**config)
self.args = copy.copy(args)
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
if not getattr(self.args, "special_first", False):
@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
if self.args.eos > 0:
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
if type(self.args.audio_vocab_size) == str:
if isinstance(self.args.audio_vocab_size, str):
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
self.n_text_tokens = self.args.text_vocab_size + 1
@ -1414,9 +1437,3 @@ class VoiceCraft(nn.Module):
flatten_gen = flatten_gen - int(self.args.n_special)
return res, flatten_gen[0].unsqueeze(0)
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
def __init__(self, config: dict):
args = Namespace(**config)
super().__init__(args)