Change default inference sampling from top-p to top-k; massive performance gain

pyp_l40
2025-03-15 18:16:27 -05:00
parent 013a21c70d
commit 7121981bb4
6 changed files with 12 additions and 10 deletions
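
For context, a minimal illustrative sketch (not VoiceCraft's actual sampling code) of what the switch means at each decoding step: with the old defaults (top_k=0, top_p<1) the next-token distribution was filtered by cumulative probability, while the new defaults (top_k=40, top_p=1) keep only the 40 most likely tokens and disable the nucleus filter. NumPy is used purely for illustration.

```python
# Illustrative sketch only, not VoiceCraft's sampling code: what moving the
# defaults from top-p (nucleus) to top-k filtering means for one decoding step.
import numpy as np

def top_k_filter(probs: np.ndarray, k: int) -> np.ndarray:
    """Keep only the k most likely tokens, then renormalize (k <= 0 disables)."""
    if k <= 0:                      # old default: top_k = 0, i.e. no top-k filtering
        return probs
    kept = np.zeros_like(probs)
    idx = np.argsort(probs)[-k:]    # indices of the k largest probabilities
    kept[idx] = probs[idx]
    return kept / kept.sum()

def top_p_filter(probs: np.ndarray, p: float) -> np.ndarray:
    """Keep the smallest set of tokens whose cumulative probability reaches p (p >= 1 disables)."""
    if p >= 1.0:                    # new default: top_p = 1, i.e. no nucleus filtering
        return probs
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    cutoff = np.searchsorted(cum, p) + 1
    kept = np.zeros_like(probs)
    kept[order[:cutoff]] = probs[order[:cutoff]]
    return kept / kept.sum()

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(2048))                       # stand-in next-token distribution
filtered = top_k_filter(top_p_filter(probs, p=1.0), k=40)  # new defaults: only top-k is active
token = rng.choice(len(probs), p=filtered)
```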


@@ -21,6 +21,8 @@ When you are inside the docker image or you have installed all dependencies, Che
If you want to do model development such as training/finetuning, I recommend following [environment setup](#environment-setup) and [training](#training).
## News
+:star: 03/15/2025: changed the default inference sampling from top-p to top-k=40 (top_p is now 1, i.e. disabled), which massively improves editing and TTS performance
:star: 04/22/2024: 330M/830M TTS Enhanced Models are up [here](https://huggingface.co/pyp1), load them through [`gradio_app.py`](./gradio_app.py) or [`inference_tts.ipynb`](./inference_tts.ipynb)! Replicate demo is up, major thanks to [@chenxwh](https://github.com/chenxwh)!
:star: 04/11/2024: VoiceCraft Gradio is now available on HuggingFace Spaces [here](https://huggingface.co/spaces/pyp1/VoiceCraft_gradio)! Major thanks to [@zuev-stepan](https://github.com/zuev-stepan), [@Sewlell](https://github.com/Sewlell), [@pgsoar](https://github.com/pgosar) [@Ph0rk0z](https://github.com/Ph0rk0z).


@@ -512,9 +512,9 @@ def get_app():
info="set to 0 to use less VRAM, but with slower inference")
left_margin = gr.Number(label="left_margin", value=0.08, info="margin to the left of the editing segment")
right_margin = gr.Number(label="right_margin", value=0.08, info="margin to the right of the editing segment")
-top_p = gr.Number(label="top_p", value=0.9, info="0.9 is a good value, 0.8 is also good")
+top_p = gr.Number(label="top_p", value=1, info="top-p sampling is disabled by default; keep this at 1")
temperature = gr.Number(label="temperature", value=1, info="haven't tried other values; changing this is not recommended")
-top_k = gr.Number(label="top_k", value=0, info="0 means we don't use topk sampling, because we use topp sampling")
+top_k = gr.Number(label="top_k", value=40, info="40 is a good default; can also try 20 or 30")
codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000, info='encodec specific, Do not change')
codec_sr = gr.Number(label="codec_sr", value=50, info='encodec specific, Do not change')
silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]", info="encodec specific, do not change")
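
As a rough sketch of how such `gr.Number` defaults typically reach the inference call (assumed wiring, not the repo's actual `get_app()`):

```python
# Assumed wiring, not the repo's gradio_app.py: Gradio passes gr.Number values
# such as top_k=40 / top_p=1 straight into the inference callback.
import gradio as gr

def run_inference(top_k, top_p, temperature):
    # hypothetical stand-in for the real editing/TTS call
    return f"sampling with top_k={int(top_k)}, top_p={top_p}, temperature={temperature}"

with gr.Blocks() as demo:
    top_k = gr.Number(label="top_k", value=40, info="40 is a good default; 20 or 30 also work")
    top_p = gr.Number(label="top_p", value=1, info="top-p sampling is disabled; keep at 1")
    temperature = gr.Number(label="temperature", value=1)
    output = gr.Textbox(label="result")
    gr.Button("Run").click(run_inference, inputs=[top_k, top_p, temperature], outputs=output)

# demo.launch()  # uncomment to serve the interface locally
```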


@@ -66,8 +66,8 @@
"right_margin = 0.08\n",
"codec_audio_sr = 16000\n",
"codec_sr = 50\n",
"top_k = 0\n",
"top_p = 0.8\n",
"top_k = 40\n",
"top_p = 1\n",
"temperature = 1\n",
"kvcache = 0\n",
"# adjust the below three arguments if the generation is not as good\n",


@@ -157,8 +157,8 @@
"# hyperparameters for inference\n",
"codec_audio_sr = 16000\n",
"codec_sr = 50\n",
"top_k = 0\n",
"top_p = 0.9 # can also try 0.8, but 0.9 seems to work better\n",
"top_k = 40 # can also try 20, 30, 50\n",
"top_p = 1 # 1 means do not do top-p sampling\n",
"temperature = 1\n",
"silence_tokens=[1388,1898,131]\n",
"kvcache = 1 # NOTE if OOM, change this to 0, or try the 330M model\n",


@@ -25,8 +25,8 @@ def get_args():
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
parser.add_argument("--top_k", type=int, default=0, help="sampling param")
parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
parser.add_argument("--top_k", type=int, default=40, help="sampling param")
parser.add_argument("--top_p", type=float, default=1, help="sampling param")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--device", type=str, default="cuda")


@@ -184,7 +184,7 @@ class Predictor(BasePredictor):
),
top_p: float = Input(
description="Default value for TTS is 0.9, and 0.8 for speech editing",
-default=0.9,
+default=1,
),
stop_repetition: int = Input(
default=3,
@@ -234,7 +234,7 @@ class Predictor(BasePredictor):
# hyperparameters for inference
codec_audio_sr = 16000
codec_sr = 50
-top_k = 0
+top_k = 40
silence_tokens = [1388, 1898, 131]
if voicecraft_model == "giga330M_TTSEnhanced.pth":
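
For reference, a minimal hedged sketch of how cog `Input` defaults like the ones above appear in a Replicate predictor signature (assumed shape, not this repo's actual predictor):

```python
# Assumed shape, not this repo's actual predictor: cog exposes each Input's
# default in the Replicate UI/API, so changing default=0.9 to default=1 here
# changes what users get when they leave the field untouched.
from cog import BasePredictor, Input

class Predictor(BasePredictor):
    def predict(
        self,
        top_k: int = Input(default=40, description="top-k sampling; 40 is the new default"),
        top_p: float = Input(default=1, description="1 disables top-p (nucleus) sampling"),
        temperature: float = Input(default=1.0, description="sampling temperature"),
    ) -> str:
        # hypothetical stand-in for the real inference call
        return f"top_k={top_k}, top_p={top_p}, temperature={temperature}"
```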