@@ -3,6 +3,7 @@ import random
 import numpy as np
 import logging
 import argparse, copy
+from typing import Dict, Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
 
 
 
-class VoiceCraft(nn.Module):
-    def __init__(self, args):
+class VoiceCraft(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="voicecraft",
+    repo_url="https://github.com/jasonppy/VoiceCraft",
+    tags=["text-to-speech"],
+):
+    def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
+        # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
+        # Won't affect instance initialization
+        if args is not None:
+            if config is not None:
+                raise ValueError("Cannot provide both `args` and `config`.")
+            config = vars(args)
+        return super().__new__(cls, args=args, config=config, **kwargs)
+
+    def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
         super().__init__()
+
+        # If loaded from HF Hub => convert config.json to Namespace args before initializing
+        if args is None:
+            if config is None:
+                raise ValueError("Either `args` or `config` must be provided.")
+            args = Namespace(**config)
+
         self.args = copy.copy(args)
         self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
         if not getattr(self.args, "special_first", False):
@@ -100,7 +123,7 @@ class VoiceCraft(nn.Module):
         if self.args.eos > 0:
             assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
             self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
-        if type(self.args.audio_vocab_size) == str:
+        if isinstance(self.args.audio_vocab_size, str):
             self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
 
         self.n_text_tokens = self.args.text_vocab_size + 1
@@ -1414,15 +1437,3 @@ class VoiceCraft(nn.Module):
             flatten_gen = flatten_gen - int(self.args.n_special)
 
         return res, flatten_gen[0].unsqueeze(0)
-
-
-class VoiceCraftHF(
-    VoiceCraft,
-    PyTorchModelHubMixin,
-    repo_url="https://github.com/jasonppy/VoiceCraft",
-    tags=["Text-to-Speech"],
-    library_name="voicecraft"
-):
-    def __init__(self, config: dict):
-        args = Namespace(**config)
-        super().__init__(args)
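
Not part of the diff: a rough usage sketch of what applying `PyTorchModelHubMixin` directly to `VoiceCraft` enables. The repo id and module path below are placeholders, not values from this PR; `from_pretrained` and `push_to_hub` are the standard methods provided by the mixin in `huggingface_hub`.

    # Rough usage sketch (placeholder repo id and import path, not from the PR).
    from models.voicecraft import VoiceCraft  # assumed module path in the VoiceCraft repo

    # Loading from the Hub: config.json is read and __init__ rebuilds Namespace(**config).
    model = VoiceCraft.from_pretrained("your-username/voicecraft")

    # Pushing to the Hub: weights plus the config.json derived from vars(args) in __new__ are uploaded.
    model.push_to_hub("your-username/voicecraft")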