Use `vars.model_type` to check for GPT-2 models

Gnome Ann 2022-01-17 13:13:54 -05:00
parent 54a587d6a3
commit 74f79081d1
1 changed file with 2 additions and 2 deletions

@@ -917,9 +917,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
     # We must disable low_cpu_mem_usage (by setting lowmem to {}) if
     # using a GPT-2 model because GPT-2 is not compatible with this
     # feature yet
-    if("/" not in vars.model and vars.model.lower().startswith("gpt2")):
+    if(vars.model_type == "gpt2"):
         lowmem = {}
     # Download model from Huggingface if it does not exist, otherwise load locally
     if(os.path.isdir(vars.custmodpth)):
         with(maybe_use_float16()):
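
For context, here is a minimal sketch of the pattern this commit moves to: instead of inferring GPT-2 from the model's name, the model type is read from the Hugging Face config, and `low_cpu_mem_usage` is only passed for architectures that support it. The `model_path` value and the loading call below are illustrative assumptions, not the repository's actual loader code.

```python
# Minimal sketch (not the repository's actual loader): determine the model type
# from the Hugging Face config rather than from the model's name, and only
# enable low_cpu_mem_usage for architectures that support it.
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "gpt2"  # illustrative; KoboldAI resolves this from vars.model / vars.custmodpth

# config.model_type is "gpt2" for every GPT-2 variant, regardless of how the
# checkpoint directory or Hugging Face repository happens to be named.
config = AutoConfig.from_pretrained(model_path)

lowmem = {"low_cpu_mem_usage": True}
if config.model_type == "gpt2":
    # GPT-2 is not compatible with low_cpu_mem_usage, so drop the argument.
    lowmem = {}

model = AutoModelForCausalLM.from_pretrained(model_path, **lowmem)
```

The advantage over the old string check is that a GPT-2 checkpoint whose directory or repository name does not start with `gpt2` is still detected correctly, since the type comes from the model's own configuration rather than from how the user named it.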