Use `vars.model_type` to check for GPT-2 models
This commit is contained in:
parent
54a587d6a3
commit
74f79081d1
|
@ -917,9 +917,9 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
# We must disable low_cpu_mem_usage (by setting lowmem to {}) if
|
# We must disable low_cpu_mem_usage (by setting lowmem to {}) if
|
||||||
# using a GPT-2 model because GPT-2 is not compatible with this
|
# using a GPT-2 model because GPT-2 is not compatible with this
|
||||||
# feature yet
|
# feature yet
|
||||||
if("/" not in vars.model and vars.model.lower().startswith("gpt2")):
|
if(vars.model_type == "gpt2"):
|
||||||
lowmem = {}
|
lowmem = {}
|
||||||
|
|
||||||
# Download model from Huggingface if it does not exist, otherwise load locally
|
# Download model from Huggingface if it does not exist, otherwise load locally
|
||||||
if(os.path.isdir(vars.custmodpth)):
|
if(os.path.isdir(vars.custmodpth)):
|
||||||
with(maybe_use_float16()):
|
with(maybe_use_float16()):
|
||||||
|
|
Loading…
Reference in New Issue