fix whisperx loading issue, update generation instruction

This commit is contained in:
Puyuan Peng 2024-04-04 20:31:07 -05:00 committed by GitHub
parent 97b1f51947
commit 0d19fa5d03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 20 additions and 21 deletions

View File

@ -1,3 +1,6 @@
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # for local use
import gradio as gr
import torch
import torchaudio
@ -6,7 +9,6 @@ from data.tokenizer import (
TextTokenizer,
)
from models import voicecraft
import os
import io
import numpy as np
import random
@ -64,7 +66,7 @@ class WhisperModel:
class WhisperxModel:
def __init__(self, model_name, align_model: WhisperxAlignModel):
from whisperx import load_model
self.model = load_model(model_name, device, asr_options={"suppress_numerals": True})
self.model = load_model(model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None})
self.align_model = align_model
def transcribe(self, audio_path):
@ -75,9 +77,6 @@ class WhisperxModel:
def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, voicecraft_model_name):
global transcribe_model, align_model, voicecraft_model
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if alignment_model_name is not None:
align_model = WhisperxAlignModel()
@ -443,7 +442,7 @@ with gr.Blocks() as app:
with gr.Row():
with gr.Column(scale=2):
input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath")
input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
with gr.Group():
original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
info="Use whisper model to get the transcript. Fix and align it if necessary.")
@ -496,22 +495,22 @@ with gr.Blocks() as app:
rerun_btn = gr.Button(value="Rerun")
with gr.Row():
with gr.Accordion("VoiceCraft config", open=False):
seed = gr.Number(label="seed", value=-1, precision=0)
left_margin = gr.Number(label="left_margin", value=0.08)
right_margin = gr.Number(label="right_margin", value=0.08)
codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
codec_sr = gr.Number(label="codec_sr", value=50)
top_k = gr.Number(label="top_k", value=0)
top_p = gr.Number(label="top_p", value=0.8)
temperature = gr.Number(label="temperature", value=1)
stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3], value=3,
info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1, -1 = disabled")
sample_batch_size = gr.Number(label="sample_batch_size", value=4, precision=0,
info="generate this many samples and choose the shortest one")
with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=3,
info="if there are long silence in the generated audio, reduce the stop_repetition to 2 or 1. -1 = disabled")
sample_batch_size = gr.Number(label="speech rate", value=4, precision=0,
info="The higher the number, the faster the output will be. Under the hood, the model will generate this many samples and choose the shortest one")
seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
info="set to 0 to use less VRAM, but with slower inference")
silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
left_margin = gr.Number(label="left_margin", value=0.08, info="margin to the left of the editing segment")
right_margin = gr.Number(label="right_margin", value=0.08, info="margin to the right of the editing segment")
top_p = gr.Number(label="top_p", value=0.8, info="0.8 is a good value, 0.9 is also good")
temperature = gr.Number(label="temperature", value=1, info="haven't try other values, do not recommend to change")
top_k = gr.Number(label="top_k", value=0, info="0 means we don't use topk sampling, because we use topp sampling")
codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000, info='encodec specific, Do not change')
codec_sr = gr.Number(label="codec_sr", value=50, info='encodec specific, Do not change')
silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]", info="encodec specific, do not change")
audio_tensors = gr.State()