Use `vars.model_type` to check for GPT-2 models
This commit is contained in:
parent
54a587d6a3
commit
74f79081d1
|
@ -917,7 +917,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
# We must disable low_cpu_mem_usage (by setting lowmem to {}) if
|
||||
# using a GPT-2 model because GPT-2 is not compatible with this
|
||||
# feature yet
|
||||
if("/" not in vars.model and vars.model.lower().startswith("gpt2")):
|
||||
if(vars.model_type == "gpt2"):
|
||||
lowmem = {}
|
||||
|
||||
# Download model from Huggingface if it does not exist, otherwise load locally
|
||||
|
|
Loading…
Reference in New Issue