Compare commits

...

22 Commits

Author SHA1 Message Date
Forkoz 348ffd59ef
Merge 6dda1a4f32 into 4a3a8f11a7 2024-04-23 13:14:37 -04:00
pyp_l40 4a3a8f11a7 small fix 2024-04-22 14:39:05 -05:00
pyp_l40 8d1177149b Replicate 2024-04-22 14:26:05 -05:00
pyp_l40 4ff9930b8e Merge branch 'chenxwh/master' 2024-04-22 13:47:49 -05:00
pyp_l40 96f6f9fc7a better handle numbers 2024-04-22 11:56:39 -05:00
chenxwh ee3955d57e Merge updates from original repository 2024-04-21 22:36:16 +00:00
chenxwh 87f4fa5d21 update 2024-04-21 22:30:56 +00:00
chenxwh 2a2ee984b6 update audiocraft install 2024-04-19 10:49:04 +00:00
chenxwh 729d0ec69e update audiocraft install 2024-04-19 10:48:15 +00:00
chenxwh ef3dd8285b Merge branch 'master' of https://github.com/chenxwh/VoiceCraft 2024-04-19 10:46:05 +00:00
chenxwh 9746a1f60c update with whisperx 2024-04-19 10:46:00 +00:00
Chenxi 4bd7b83b57
Merge branch 'jasonppy:master' into master 2024-04-19 11:45:24 +01:00
Chenxi 6e5382584c
Merge branch 'jasonppy:master' into master 2024-04-17 16:27:36 +01:00
chenxwh 0da8ee4b7a update replicate demo 2024-04-14 12:15:23 +00:00
Chenxi e3fc926ca4
Merge branch 'jasonppy:master' into master 2024-04-14 09:31:28 +01:00
chenxwh 0c6942fd2a Merge branch 'master' of https://github.com/chenxwh/VoiceCraft 2024-04-12 14:23:21 +00:00
chenxwh f649f9216b Merged changes from upstream 2024-04-12 14:18:51 +00:00
Chenxi 1e2f8391a7
Merge branch 'jasonppy:master' into master 2024-04-05 21:31:39 +01:00
chenxwh b8eca5a2d4 replicate demo 2024-04-05 17:58:09 +00:00
Forkoz 6dda1a4f32
Float16 KV Cache in voicecraft.py 2024-04-05 17:52:28 +00:00
chenxwh 023d4b1c6c replicate demo 2024-04-05 17:23:39 +00:00
chenxwh 49a648fa54 Replicate TTS v1 demo 2024-04-05 15:20:11 +00:00
9 changed files with 457 additions and 9 deletions

17
.dockerignore Normal file
View File

@ -0,0 +1,17 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
# Exclude Git files
.git
.github
.gitignore
# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache
# Exclude Python virtual environment
/venv

3
.gitignore vendored
View File

@ -29,4 +29,5 @@ src/audiocraft
!/demo/
!/demo/*
/demo/temp/*.txt
!/demo/temp/84_121550_000074_000000.txt
!/demo/temp/84_121550_000074_000000.txt
.cog/tmp/*

View File

@ -1,5 +1,6 @@
# VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild
[![Paper](https://img.shields.io/badge/arXiv-2301.12503-brightgreen.svg?style=flat-square)](https://jasonppy.github.io/assets/pdfs/VoiceCraft.pdf) [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing)
[![Paper](https://img.shields.io/badge/arXiv-2403.16973-brightgreen.svg?style=flat-square)](https://arxiv.org/pdf/2403.16973.pdf) [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1IOjpglQyMTO2C3Y94LD9FY0Ocn-RJRg6?usp=sharing) [![Replicate](https://replicate.com/cjwbw/voicecraft/badge)](https://replicate.com/cjwbw/voicecraft) [![YouTube demo](https://img.shields.io/youtube/views/eikybOi8iwU)](https://youtu.be/eikybOi8iwU) [![Demo page](https://img.shields.io/badge/Audio_Samples-blue?logo=Github&style=flat-square)](https://jasonppy.github.io/VoiceCraft_web/)
### TL;DR
VoiceCraft is a token infilling neural codec language model, that achieves state-of-the-art performance on both **speech editing** and **zero-shot text-to-speech (TTS)** on in-the-wild data including audiobooks, internet videos, and podcasts.
@ -18,6 +19,8 @@ When you are inside the docker image or you have installed all dependencies, Che
If you want to do model development such as training/finetuning, I recommend following [envrionment setup](#environment-setup) and [training](#training).
## News
:star: 04/22/2024: 330M/830M TTS Enhanced Models are up [here](https://huggingface.co/pyp1), load them through [`gradio_app.py`](./gradio_app.py) or [`inference_tts.ipynb`](./inference_tts.ipynb)! Replicate demo is up, major thanks to [@chenxwh](https://github.com/chenxwh)!
:star: 04/11/2024: VoiceCraft Gradio is now available on HuggingFace Spaces [here](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio)! Major thanks to [@zuev-stepan](https://github.com/zuev-stepan), [@Sewlell](https://github.com/Sewlell), [@pgsoar](https://github.com/pgosar) [@Ph0rk0z](https://github.com/Ph0rk0z).
:star: 04/05/2024: I finetuned giga330M with the TTS objective on gigaspeech and 1/5 of librilight. Weights are [here](https://huggingface.co/pyp1/VoiceCraft/tree/main). Make sure maximal prompt + generation length <= 16 seconds (due to our limited compute, we had to drop utterances longer than 16s in training data). Even stronger models forthcomming, stay tuned!
@ -30,7 +33,7 @@ If you want to do model development such as training/finetuning, I recommend fol
- [x] Inference demo for speech editing and TTS
- [x] Training guidance
- [x] RealEdit dataset and training manifest
- [x] Model weights (giga330M.pth, giga830M.pth, and gigaHalfLibri330M_TTSEnhanced_max16s.pth)
- [x] Model weights
- [x] Better guidance on training/finetuning
- [x] Colab notebooks
- [x] HuggingFace Spaces demo
@ -210,7 +213,7 @@ We thank Feiteng for his [VALL-E reproduction](https://github.com/lifeiteng/vall
## Citation
```
@article{peng2024voicecraft,
author = {Peng, Puyuan and Huang, Po-Yao and Li, Daniel and Mohamed, Abdelrahman and Harwath, David},
author = {Peng, Puyuan and Huang, Po-Yao and Mohamed, Abdelrahman and Harwath, David},
title = {VoiceCraft: Zero-Shot Speech Editing and Text-to-Speech in the Wild},
journal = {arXiv},
year = {2024},

24
cog.yaml Normal file
View File

@ -0,0 +1,24 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
build:
gpu: true
system_packages:
- libgl1-mesa-glx
- libglib2.0-0
- ffmpeg
- espeak-ng
python_version: "3.11"
python_packages:
- torch==2.1.0
- torchaudio==2.1.0
- xformers
- phonemizer==3.2.1
- whisperx==3.1.1
- openai-whisper>=20231117
run:
- git clone https://github.com/facebookresearch/audiocraft && pip install -e ./audiocraft
- pip install "pydantic<2.0.0"
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
- mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth" "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth"
predict: "predict.py:Predictor"

View File

@ -1,4 +1,6 @@
import os
import re
from num2words import num2words
import gradio as gr
import torch
import torchaudio
@ -83,7 +85,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
elif voicecraft_model_name == "830M":
voicecraft_model_name = "giga830M"
elif voicecraft_model_name == "330M_TTSEnhanced":
voicecraft_model_name = "gigaHalfLibri330M_TTSEnhanced_max16s"
voicecraft_model_name = "330M_TTSEnhanced"
elif voicecraft_model_name == "830M_TTSEnhanced":
voicecraft_model_name = "830M_TTSEnhanced"
@ -201,6 +203,15 @@ def get_output_audio(audio_tensors, codec_audio_sr):
buffer.seek(0)
return buffer.read()
def replace_numbers_with_words(sentence):
sentence = re.sub(r'(\d+)', r' \1 ', sentence) # add spaces around numbers
def replace_with_words(match):
num = match.group(0)
try:
return num2words(num) # Convert numbers to words
except:
return num # In case num2words fails (unlikely with digits but just to be safe)
return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
stop_repetition, sample_batch_size, kvcache, silence_tokens,
@ -213,6 +224,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
raise gr.Error("Can't use smart transcript: whisper transcript not found")
seed_everything(seed)
transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ") # replace numbers with words, so that the phonemizer can do a better job
if mode == "Long TTS":
if split_text == "Newline":
sentences = transcript.split('\n')

View File

@ -4,3 +4,4 @@ openai-whisper>=20231117
aeneas>=1.7.3.0
whisperx>=3.1.1
huggingface_hub==0.22.2
num2words==0.5.13

View File

@ -71,7 +71,7 @@
"# load model, encodec, and phn2num\n",
"# # load model, tokenizer, and other necessary files\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"voicecraft_name=\"830M_TTSEnhanced.pth\" # or giga330M.pth, gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"voicecraft_name=\"830M_TTSEnhanced.pth\" # or giga330M.pth, 330M_TTSEnhanced.pth, giga830M.pth\n",
"\n",
"# the new way of loading the model, with huggingface, recommended\n",
"from models import voicecraft\n",

View File

@ -711,7 +711,7 @@ class VoiceCraft(
##################### silence repetition handling #####################
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# handle multi-span kv-cache
new_masked_span = False
@ -1011,7 +1011,7 @@ class VoiceCraft(
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@ -1261,7 +1261,7 @@ class VoiceCraft(
# prepare the cache placeholder
# n_layers, 2, bsz, num_heads, src_len, head_dim
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")

389
predict.py Normal file
View File

@ -0,0 +1,389 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import time
import random
import getpass
import shutil
import subprocess
import torch
import numpy as np
import torchaudio
from cog import BasePredictor, Input, Path, BaseModel
os.environ["USER"] = getpass.getuser()
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
from models import voicecraft
from inference_tts_scale import inference_one_sample
from edit_utils import get_span
from inference_speech_editing_scale import (
inference_one_sample as inference_one_sample_editing,
)
MODEL_URL = "https://weights.replicate.delivery/default/pyp1/VoiceCraft-models.tar" # all the models are cached and uploaded to replicate.delivery for faster booting
MODEL_CACHE = "model_cache"
class ModelOutput(BaseModel):
whisper_transcript_orig_audio: str
generated_audio: Path
class WhisperxAlignModel:
def __init__(self):
from whisperx import load_align_model
self.model, self.metadata = load_align_model(
language_code="en", device="cuda:0"
)
def align(self, segments, audio_path):
from whisperx import align, load_audio
audio = load_audio(audio_path)
return align(
segments,
self.model,
self.metadata,
audio,
device="cuda:0",
return_char_alignments=False,
)["segments"]
class WhisperxModel:
def __init__(self, model_name, align_model: WhisperxAlignModel, device="cuda"):
from whisperx import load_model
# the model weights are cached from Systran/faster-whisper-base.en etc
self.model = load_model(
model_name,
device,
asr_options={
"suppress_numerals": True,
"max_new_tokens": None,
"clip_timestamps": None,
"hallucination_silence_threshold": None,
},
)
self.align_model = align_model
def transcribe(self, audio_path):
segments = self.model.transcribe(audio_path, language="en", batch_size=8)[
"segments"
]
return self.align_model.align(segments, audio_path)
def download_weights(url, dest):
start = time.time()
print("downloading url: ", url)
print("downloading to: ", dest)
subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
print("downloading took: ", time.time() - start)
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.device = "cuda"
if not os.path.exists(MODEL_CACHE):
download_weights(MODEL_URL, MODEL_CACHE)
encodec_fn = f"{MODEL_CACHE}/encodec_4cb2048_giga.th"
self.models, self.ckpt, self.phn2num = {}, {}, {}
for voicecraft_name in [
"giga830M.pth",
"giga330M.pth",
"gigaHalfLibri330M_TTSEnhanced_max16s.pth",
]:
ckpt_fn = f"{MODEL_CACHE}/{voicecraft_name}"
self.ckpt[voicecraft_name] = torch.load(ckpt_fn, map_location="cpu")
self.models[voicecraft_name] = voicecraft.VoiceCraft(
self.ckpt[voicecraft_name]["config"]
)
self.models[voicecraft_name].load_state_dict(
self.ckpt[voicecraft_name]["model"]
)
self.models[voicecraft_name].to(self.device)
self.models[voicecraft_name].eval()
self.phn2num[voicecraft_name] = self.ckpt[voicecraft_name]["phn2num"]
self.text_tokenizer = TextTokenizer(backend="espeak")
self.audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=self.device)
align_model = WhisperxAlignModel()
self.transcribe_models = {
k: WhisperxModel(f"{MODEL_CACHE}/whisperx_{k.split('.')[0]}", align_model)
for k in ["base.en", "small.en", "medium.en"]
}
def predict(
self,
task: str = Input(
description="Choose a task",
choices=[
"speech_editing-substitution",
"speech_editing-insertion",
"speech_editing-deletion",
"zero-shot text-to-speech",
],
default="zero-shot text-to-speech",
),
voicecraft_model: str = Input(
description="Choose a model",
choices=["giga830M.pth", "giga330M.pth", "giga330M_TTSEnhanced.pth"],
default="giga330M_TTSEnhanced.pth",
),
orig_audio: Path = Input(description="Original audio file"),
orig_transcript: str = Input(
description="Optionally provide the transcript of the input audio. Leave it blank to use the WhisperX model below to generate the transcript. Inaccurate transcription may lead to error TTS or speech editing",
default="",
),
whisperx_model: str = Input(
description="If orig_transcript is not provided above, choose a WhisperX model for generating the transcript. Inaccurate transcription may lead to error TTS or speech editing. You can modify the generated transcript and provide it directly to orig_transcript above",
choices=[
"base.en",
"small.en",
"medium.en",
],
default="base.en",
),
target_transcript: str = Input(
description="Transcript of the target audio file",
),
cut_off_sec: float = Input(
description="Only used for for zero-shot text-to-speech task. The first seconds of the original audio that are used for zero-shot text-to-speech. 3 sec of reference is generally enough for high quality voice cloning, but longer is generally better, try e.g. 3~6 sec",
default=3.01,
),
kvcache: int = Input(
description="Set to 0 to use less VRAM, but with slower inference",
choices=[0, 1],
default=1,
),
left_margin: float = Input(
description="Margin to the left of the editing segment",
default=0.08,
),
right_margin: float = Input(
description="Margin to the right of the editing segment",
default=0.08,
),
temperature: float = Input(
description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic. Do not recommend to change",
default=1,
),
top_p: float = Input(
description="Default value for TTS is 0.9, and 0.8 for speech editing",
default=0.9,
),
stop_repetition: int = Input(
default=3,
description="Default value for TTS is 3, and -1 for speech editing. -1 means do not adjust prob of silence tokens. if there are long silence or unnaturally stretched words, increase sample_batch_size to 2, 3 or even 4",
),
sample_batch_size: int = Input(
description="Default value for TTS is 4, and 1 for speech editing. The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one",
default=4,
),
seed: int = Input(
description="Random seed. Leave blank to randomize the seed", default=None
),
) -> ModelOutput:
"""Run a single prediction on the model"""
if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
seed_everything(seed)
segments = self.transcribe_models[whisperx_model].transcribe(
str(orig_audio)
)
state = get_transcribe_state(segments)
whisper_transcript = state["transcript"].strip()
if len(orig_transcript.strip()) == 0:
orig_transcript = whisper_transcript
print(f"The transcript from the Whisper model: {whisper_transcript}")
temp_folder = "exp_dir"
if os.path.exists(temp_folder):
shutil.rmtree(temp_folder)
os.makedirs(temp_folder)
filename = "orig_audio"
audio_fn = str(orig_audio)
info = torchaudio.info(audio_fn)
audio_dur = info.num_frames / info.sample_rate
# hyperparameters for inference
codec_audio_sr = 16000
codec_sr = 50
top_k = 0
silence_tokens = [1388, 1898, 131]
if voicecraft_model == "giga330M_TTSEnhanced.pth":
voicecraft_model = "gigaHalfLibri330M_TTSEnhanced_max16s.pth"
if task == "zero-shot text-to-speech":
assert (
cut_off_sec < audio_dur
), f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
prompt_end_frame = int(cut_off_sec * info.sample_rate)
idx = find_closest_cut_off_word(state["word_bounds"], cut_off_sec)
orig_transcript_until_cutoff_time = " ".join(
[word_bound["word"] for word_bound in state["word_bounds"][: idx + 1]]
)
else:
edit_type = task.split("-")[-1]
orig_span, new_span = get_span(
orig_transcript, target_transcript, edit_type
)
if orig_span[0] > orig_span[1]:
RuntimeError(f"example {audio_fn} failed")
if orig_span[0] == orig_span[1]:
orig_span_save = [orig_span[0]]
else:
orig_span_save = orig_span
if new_span[0] == new_span[1]:
new_span_save = [new_span[0]]
else:
new_span_save = new_span
orig_span_save = ",".join([str(item) for item in orig_span_save])
new_span_save = ",".join([str(item) for item in new_span_save])
start, end = get_mask_interval_from_word_bounds(
state["word_bounds"], orig_span_save, edit_type
)
# span in codec frames
morphed_span = (
max(start - left_margin, 1 / codec_sr),
min(end + right_margin, audio_dur),
) # in seconds
mask_interval = [
[round(morphed_span[0] * codec_sr), round(morphed_span[1] * codec_sr)]
]
mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
decode_config = {
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"stop_repetition": stop_repetition,
"kvcache": kvcache,
"codec_audio_sr": codec_audio_sr,
"codec_sr": codec_sr,
"silence_tokens": silence_tokens,
}
if task == "zero-shot text-to-speech":
decode_config["sample_batch_size"] = sample_batch_size
_, gen_audio = inference_one_sample(
self.models[voicecraft_model],
self.ckpt[voicecraft_model]["config"],
self.phn2num[voicecraft_model],
self.text_tokenizer,
self.audio_tokenizer,
audio_fn,
orig_transcript_until_cutoff_time.strip()
+ " "
+ target_transcript.strip(),
self.device,
decode_config,
prompt_end_frame,
)
else:
_, gen_audio = inference_one_sample_editing(
self.models[voicecraft_model],
self.ckpt[voicecraft_model]["config"],
self.phn2num[voicecraft_model],
self.text_tokenizer,
self.audio_tokenizer,
audio_fn,
target_transcript,
mask_interval,
self.device,
decode_config,
)
# save segments for comparison
gen_audio = gen_audio[0].cpu()
out = "/tmp/out.wav"
torchaudio.save(out, gen_audio, codec_audio_sr)
return ModelOutput(
generated_audio=Path(out), whisper_transcript_orig_audio=whisper_transcript
)
def seed_everything(seed):
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def get_transcribe_state(segments):
words_info = [word_info for segment in segments for word_info in segment["words"]]
return {
"transcript": " ".join([segment["text"].strip() for segment in segments]),
"word_bounds": [
{"word": word["word"], "start": word["start"], "end": word["end"]}
for word in words_info
],
}
def find_closest_cut_off_word(word_bounds, cut_off_sec):
min_distance = float("inf")
for i, word_bound in enumerate(word_bounds):
distance = abs(word_bound["start"] - cut_off_sec)
if distance < min_distance:
min_distance = distance
if word_bound["end"] > cut_off_sec:
break
return i
def get_mask_interval_from_word_bounds(word_bounds, word_span_ind, editType):
tmp = word_span_ind.split(",")
s, e = int(tmp[0]), int(tmp[-1])
start = None
for j, item in enumerate(word_bounds):
if j == s:
if editType == "insertion":
start = float(item["end"])
else:
start = float(item["start"])
if j == e:
if editType == "insertion":
end = float(item["start"])
else:
end = float(item["end"])
assert start is not None
break
return (start, end)