Remove Replace from Huggingface

Accidentally ended up in the wrong section, for downloads we do not replace anything only afterwards.
This commit is contained in:
henk717 2021-12-23 17:27:09 +01:00
parent e7aa92cd86
commit 25a6e489c1

View File

@ -890,11 +890,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
print("Model does not exist locally, attempting to download from Huggingface...")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
with(maybe_use_float16()):
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **maybe_low_cpu_mem_usage())
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **maybe_low_cpu_mem_usage())
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **maybe_low_cpu_mem_usage())
model = model.half()
import shutil
shutil.rmtree("cache/")