hf model download

This commit is contained in:
jason-on-salt-a40
2024-04-13 15:11:15 -07:00
parent 3a8d5f4aab
commit 57079c44b6
7 changed files with 73 additions and 64 deletions

View File

@@ -92,27 +92,22 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
transcribe_model = WhisperxModel(whisper_model_name, align_model)
voicecraft_name = f"{voicecraft_model_name}.pth"
ckpt_fn = f"{MODELS_PATH}/{voicecraft_name}"
model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
phn2num = model.args.phn2num
config = model.args
model.to(device)
encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
os.system(f"mv {voicecraft_name}\?download\=true {MODELS_PATH}/{voicecraft_name}")
if not os.path.exists(encodec_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
os.system(f"mv encodec_4cb2048_giga.th {MODELS_PATH}/encodec_4cb2048_giga.th")
ckpt = torch.load(ckpt_fn, map_location="cpu")
model = voicecraft.VoiceCraft(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.to(device)
model.eval()
voicecraft_model = {
"ckpt": ckpt,
"config": config,
"phn2num": phn2num,
"model": model,
"text_tokenizer": TextTokenizer(backend="espeak"),
"audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
return gr.Accordion()
@@ -254,8 +249,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
_, gen_audio = inference_one_sample(voicecraft_model["model"],
voicecraft_model["ckpt"]["config"],
voicecraft_model["ckpt"]["phn2num"],
voicecraft_model["config"],
voicecraft_model["phn2num"],
voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
audio_path, target_transcript, device, decode_config,
prompt_end_frame)
@@ -283,8 +278,8 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
mask_interval = torch.LongTensor(mask_interval)
_, gen_audio = inference_one_sample(voicecraft_model["model"],
voicecraft_model["ckpt"]["config"],
voicecraft_model["ckpt"]["phn2num"],
voicecraft_model["config"],
voicecraft_model["phn2num"],
voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
audio_path, target_transcript, mask_interval, device, decode_config)
gen_audio = gen_audio[0].cpu()