Compare commits

...

4 Commits

Author SHA1 Message Date
jason-on-salt-a40 9dab235647 better hf integration 2024-04-16 08:55:35 -07:00
Wauplin e550f61409
Deduplicate VoiceCraftHF <> VoiceCraft 2024-04-16 10:24:05 +02:00
Lucain 943211d751
Update models/voicecraft.py
Co-authored-by: Julien Chaumond <julien@huggingface.co>
2024-04-16 08:47:49 +02:00
Lucain 77df5104b0
Tweak VoiceCraft x HF integration
This PR tweaks the HF integrations:
- `VoiceCraft` tag is lowercased to `voicecraft` => not a hard requirement but makes it more consistent with other libraries on the Hub.
- `voicecraft` is set as the `library_name` instead of simply a tag. This is better for taxonomy on the Hub.

Regarding the integration, I also opened https://github.com/huggingface/huggingface.js/pull/626 to make it more official on the Hub. In particular, there will now be an official `</> Use in VoiceCraft`  button in all voicecraft models that display the code snippet to load the model. This should help users getting started with the model. It will also add a link to the voicecraft repo for the installation guide.

cc @NielsRogge who opened https://github.com/jasonppy/VoiceCraft/pull/78
2024-04-15 11:51:25 +02:00
4 changed files with 31 additions and 14 deletions

View File

@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
transcribe_model = WhisperxModel(whisper_model_name, align_model)
voicecraft_name = f"{voicecraft_model_name}.pth"
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
phn2num = model.args.phn2num
config = model.args
model.to(device)

View File

@ -203,8 +203,8 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n",
"# the new way of loading the model, with huggingface, recommended\n",
"from models.voicecraft import VoiceCraftHF\n",
"model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"from models import voicecraft\n",
"model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n",
"config = vars(model.args)\n",
"model.to(device)\n",

View File

@ -74,8 +74,8 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n",
"# the new way of loading the model, with huggingface, recommended\n",
"from models.voicecraft import VoiceCraftHF\n",
"model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"from models import voicecraft\n",
"model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n",
"config = vars(model.args)\n",
"model.to(device)\n",

View File

@ -3,6 +3,7 @@ import random
import numpy as np
import logging
import argparse, copy
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
class VoiceCraft(nn.Module):
def __init__(self, args):
class VoiceCraft(
nn.Module,
PyTorchModelHubMixin,
library_name="voicecraft",
repo_url="https://github.com/jasonppy/VoiceCraft",
tags=["text-to-speech"],
):
def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
# If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
# Won't affect instance initialization
if args is not None:
if config is not None:
raise ValueError("Cannot provide both `args` and `config`.")
config = vars(args)
return super().__new__(cls, args=args, config=config, **kwargs)
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
super().__init__()
# If loaded from HF Hub => convert config.json to Namespace args before initializing
if args is None:
if config is None:
raise ValueError("Either `args` or `config` must be provided.")
args = Namespace(**config)
self.args = copy.copy(args)
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
if not getattr(self.args, "special_first", False):
@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
if self.args.eos > 0:
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
if type(self.args.audio_vocab_size) == str:
if isinstance(self.args.audio_vocab_size, str):
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
self.n_text_tokens = self.args.text_vocab_size + 1
@ -1414,9 +1437,3 @@ class VoiceCraft(nn.Module):
flatten_gen = flatten_gen - int(self.args.n_special)
return res, flatten_gen[0].unsqueeze(0)
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
def __init__(self, config: dict):
args = Namespace(**config)
super().__init__(args)