Path loading improvements

This fixes a few scenarios from my commit yesterday. Models with a / in their name are now first loaded from the converted directory (with the / replaced) if it exists, before we fall back to the original name, so the config is loaded from the correct location. It also includes cache dir fixes and an improved routine for path-loaded models that mimics the NeoCustom option, fixing models that have no model_type specified. Because GPT2 doesn't work well with this option and should be used exclusively with GPT2Custom, and GPT-J models should have a model_type, we assume it is a Neo model when no type is specified.
henk717 2021-12-23 14:40:35 +01:00
parent a2d8347939
commit be351e384d
1 changed file with 25 additions and 17 deletions
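In short, the load order for a model name that contains a / now looks roughly like this (a minimal sketch of the idea, not the code in aiserver.py; the function name is illustrative):

import os

def resolve_local_path(model_name):
    # Converted models are stored locally with the "/" of the Hugging Face
    # name replaced by "_", so prefer that directory when it exists.
    converted = model_name.replace('/', '_')
    if os.path.isdir(converted):
        return converted
    # Otherwise fall back to the original name so the config is still read
    # from the right place (e.g. downloaded from the Hugging Face hub).
    return model_name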


@@ -396,17 +396,23 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     vars.custmodpth = vars.model
     # Get the model_type from the config or assume a model type if it isn't present
     from transformers import AutoConfig
-    try:
-        model_config = AutoConfig.from_pretrained(vars.custmodpth)
-    except ValueError as e:
-        vars.model_type = "not_found"
-    if(not vars.model_type == "not_found"):
-        vars.model_type = model_config.model_type
-    elif(vars.model == "NeoCustom"):
-        vars.model_type = "gpt_neo"
-    elif(vars.model == "GPT2Custom"):
-        vars.model_type = "gpt2"
+    if(os.path.isdir(vars.custmodpth.replace('/', '_'))):
+        try:
+            model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/")
+            vars.model_type = model_config.model_type
+        except ValueError as e:
+            vars.model_type = "not_found"
     else:
+        try:
+            model_config = AutoConfig.from_pretrained(vars.custmodpth, cache_dir="cache/")
+            vars.model_type = model_config.model_type
+        except ValueError as e:
+            vars.model_type = "not_found"
+    if(vars.model_type == "not_found" and vars.model == "NeoCustom"):
+        vars.model_type = "gpt_neo"
+    elif(vars.model_type == "not_found" and vars.model == "GPT2Custom"):
+        vars.model_type = "gpt2"
+    elif(vars.model_type == "not_found"):
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@@ -891,8 +897,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
             if(os.path.isdir(vars.model.replace('/', '_'))):
                 with(maybe_use_float16()):
                     tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
-                    model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
+                    try:
+                        model = AutoModelForCausalLM.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/", **maybe_low_cpu_mem_usage())
+                    except ValueError as e:
+                        model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/", **maybe_low_cpu_mem_usage())
             else:
                 print("Model does not exist locally, attempting to download from Huggingface...")
                 tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
@@ -932,15 +940,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
         else:
             from transformers import GPT2TokenizerFast
-            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
     else:
         # If we're running Colab or OAI, we still need a tokenizer.
         if(vars.model == "Colab"):
             from transformers import GPT2TokenizerFast
-            tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B")
+            tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
         elif(vars.model == "OAI"):
             from transformers import GPT2TokenizerFast
-            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
         # Load the TPU backend if requested
         elif(vars.model == "TPUMeshTransformerGPTJ"):
             print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@@ -1088,7 +1096,7 @@ def lua_decode(tokens):
     if("tokenizer" not in globals()):
         from transformers import GPT2TokenizerFast
         global tokenizer
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
     return tokenizer.decode(tokens)
 
 #==================================================================#
@@ -1099,7 +1107,7 @@ def lua_encode(string):
     if("tokenizer" not in globals()):
         from transformers import GPT2TokenizerFast
         global tokenizer
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
     return tokenizer.encode(string, max_length=int(4e9), truncation=True)
 
 #==================================================================#