Merge branch 'jasonppy:master' into master

This commit is contained in:
Chenxi 2024-04-17 16:27:36 +01:00 committed by GitHub
commit 6e5382584c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 14 deletions

View File

@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
transcribe_model = WhisperxModel(whisper_model_name, align_model) transcribe_model = WhisperxModel(whisper_model_name, align_model)
voicecraft_name = f"{voicecraft_model_name}.pth" voicecraft_name = f"{voicecraft_model_name}.pth"
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
phn2num = model.args.phn2num phn2num = model.args.phn2num
config = model.args config = model.args
model.to(device) model.to(device)

View File

@ -203,8 +203,8 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n", "\n",
"# the new way of loading the model, with huggingface, recommended\n", "# the new way of loading the model, with huggingface, recommended\n",
"from models.voicecraft import VoiceCraftHF\n", "from models import voicecraft\n",
"model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n", "phn2num = model.args.phn2num\n",
"config = vars(model.args)\n", "config = vars(model.args)\n",
"model.to(device)\n", "model.to(device)\n",

View File

@ -74,8 +74,8 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n", "\n",
"# the new way of loading the model, with huggingface, recommended\n", "# the new way of loading the model, with huggingface, recommended\n",
"from models.voicecraft import VoiceCraftHF\n", "from models import voicecraft\n",
"model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n", "phn2num = model.args.phn2num\n",
"config = vars(model.args)\n", "config = vars(model.args)\n",
"model.to(device)\n", "model.to(device)\n",

View File

@ -3,6 +3,7 @@ import random
import numpy as np import numpy as np
import logging import logging
import argparse, copy import argparse, copy
from typing import Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module): class VoiceCraft(
def __init__(self, args): nn.Module,
PyTorchModelHubMixin,
library_name="voicecraft",
repo_url="https://github.com/jasonppy/VoiceCraft",
tags=["text-to-speech"],
):
def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
# If initialized from Namespace `args`, convert them to a dict `config` so that 'PyTorchModelHubMixin' can serialize it as config.json.
# This conversion does not affect instance initialization.
if args is not None:
if config is not None:
raise ValueError("Cannot provide both `args` and `config`.")
config = vars(args)
return super().__new__(cls, args=args, config=config, **kwargs)
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
super().__init__() super().__init__()
# If loaded from the HF Hub, convert the config.json contents to Namespace args before initializing.
if args is None:
if config is None:
raise ValueError("Either `args` or `config` must be provided.")
args = Namespace(**config)
self.args = copy.copy(args) self.args = copy.copy(args)
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks) self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
if not getattr(self.args, "special_first", False): if not getattr(self.args, "special_first", False):
@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
if self.args.eos > 0: if self.args.eos > 0:
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1] self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
if type(self.args.audio_vocab_size) == str: if isinstance(self.args.audio_vocab_size, str):
self.args.audio_vocab_size = eval(self.args.audio_vocab_size) self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
self.n_text_tokens = self.args.text_vocab_size + 1 self.n_text_tokens = self.args.text_vocab_size + 1
@ -1414,9 +1437,3 @@ class VoiceCraft(nn.Module):
flatten_gen = flatten_gen - int(self.args.n_special) flatten_gen = flatten_gen - int(self.args.n_special)
return res, flatten_gen[0].unsqueeze(0) return res, flatten_gen[0].unsqueeze(0)
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
def __init__(self, config: dict):
args = Namespace(**config)
super().__init__(args)