mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2025-06-05 21:49:11 +02:00
hf model download
This commit is contained in:
@@ -92,27 +92,22 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
|
||||
transcribe_model = WhisperxModel(whisper_model_name, align_model)
|
||||
|
||||
voicecraft_name = f"{voicecraft_model_name}.pth"
|
||||
ckpt_fn = f"{MODELS_PATH}/{voicecraft_name}"
|
||||
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
|
||||
phn2num = model.args.phn2num
|
||||
config = model.args
|
||||
model.to(device)
|
||||
|
||||
encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
|
||||
if not os.path.exists(ckpt_fn):
|
||||
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
|
||||
os.system(f"mv {voicecraft_name}\?download\=true {MODELS_PATH}/{voicecraft_name}")
|
||||
if not os.path.exists(encodec_fn):
|
||||
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
|
||||
os.system(f"mv encodec_4cb2048_giga.th {MODELS_PATH}/encodec_4cb2048_giga.th")
|
||||
|
||||
ckpt = torch.load(ckpt_fn, map_location="cpu")
|
||||
model = voicecraft.VoiceCraft(ckpt["config"])
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
voicecraft_model = {
|
||||
"ckpt": ckpt,
|
||||
"config": config,
|
||||
"phn2num": phn2num,
|
||||
"model": model,
|
||||
"text_tokenizer": TextTokenizer(backend="espeak"),
|
||||
"audio_tokenizer": AudioTokenizer(signature=encodec_fn)
|
||||
}
|
||||
|
||||
return gr.Accordion()
|
||||
|
||||
|
||||
@@ -254,8 +249,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
|
||||
|
||||
prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
|
||||
_, gen_audio = inference_one_sample(voicecraft_model["model"],
|
||||
voicecraft_model["ckpt"]["config"],
|
||||
voicecraft_model["ckpt"]["phn2num"],
|
||||
voicecraft_model["config"],
|
||||
voicecraft_model["phn2num"],
|
||||
voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
|
||||
audio_path, target_transcript, device, decode_config,
|
||||
prompt_end_frame)
|
||||
@@ -283,8 +278,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
|
||||
mask_interval = torch.LongTensor(mask_interval)
|
||||
|
||||
_, gen_audio = inference_one_sample(voicecraft_model["model"],
|
||||
voicecraft_model["ckpt"]["config"],
|
||||
voicecraft_model["ckpt"]["phn2num"],
|
||||
voicecraft_model["config"],
|
||||
voicecraft_model["phn2num"],
|
||||
voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
|
||||
audio_path, target_transcript, mask_interval, device, decode_config)
|
||||
gen_audio = gen_audio[0].cpu()
|
||||
|
Reference in New Issue
Block a user