Path loading improvements
This fixes a few scenario's of my commit yesterday, models that have a / are now first loaded from the corrected directory if it exists before we fall back to its original name to make sure it loads the config from the correct location. Cache dir fixes and a improved routine for the path loaded models that mimics the NeoCustom option fixing models that have no model_type specified. Because GPT2 doesn't work well with this option and should exclusively be used with the GPT2Custom and GPT-J models should have a model_type we assume its a Neo model when not specified.
This commit is contained in:
parent
a2d8347939
commit
be351e384d
36
aiserver.py
36
aiserver.py
|
@ -396,17 +396,23 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
vars.custmodpth = vars.model
|
vars.custmodpth = vars.model
|
||||||
# Get the model_type from the config or assume a model type if it isn't present
|
# Get the model_type from the config or assume a model type if it isn't present
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
if(os.path.isdir(vars.custmodpth.replace('/', '_'))):
|
||||||
try:
|
try:
|
||||||
model_config = AutoConfig.from_pretrained(vars.custmodpth)
|
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/")
|
||||||
|
vars.model_type = model_config.model_type
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
vars.model_type = "not_found"
|
vars.model_type = "not_found"
|
||||||
if(not vars.model_type == "not_found"):
|
|
||||||
vars.model_type = model_config.model_type
|
|
||||||
elif(vars.model == "NeoCustom"):
|
|
||||||
vars.model_type = "gpt_neo"
|
|
||||||
elif(vars.model == "GPT2Custom"):
|
|
||||||
vars.model_type = "gpt2"
|
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
|
model_config = AutoConfig.from_pretrained(vars.custmodpth, cache_dir="cache/")
|
||||||
|
vars.model_type = model_config.model_type
|
||||||
|
except ValueError as e:
|
||||||
|
vars.model_type = "not_found"
|
||||||
|
if(vars.model_type == "not_found" and vars.model == "NeoCustom"):
|
||||||
|
vars.model_type = "gpt_neo"
|
||||||
|
elif(vars.model_type == "not_found" and vars.model == "GPT2Custom"):
|
||||||
|
vars.model_type = "gpt2"
|
||||||
|
elif(vars.model_type == "not_found"):
|
||||||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||||
vars.model_type = "gpt_neo"
|
vars.model_type = "gpt_neo"
|
||||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||||
|
@ -891,8 +897,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
if(os.path.isdir(vars.model.replace('/', '_'))):
|
if(os.path.isdir(vars.model.replace('/', '_'))):
|
||||||
with(maybe_use_float16()):
|
with(maybe_use_float16()):
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
|
||||||
model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
|
try:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
||||||
|
except ValueError as e:
|
||||||
|
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/", **maybe_low_cpu_mem_usage())
|
||||||
else:
|
else:
|
||||||
print("Model does not exist locally, attempting to download from Huggingface...")
|
print("Model does not exist locally, attempting to download from Huggingface...")
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
|
||||||
|
@ -932,15 +940,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||||
else:
|
else:
|
||||||
# If we're running Colab or OAI, we still need a tokenizer.
|
# If we're running Colab or OAI, we still need a tokenizer.
|
||||||
if(vars.model == "Colab"):
|
if(vars.model == "Colab"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B")
|
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||||
# Load the TPU backend if requested
|
# Load the TPU backend if requested
|
||||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
|
@ -1088,7 +1096,7 @@ def lua_decode(tokens):
|
||||||
if("tokenizer" not in globals()):
|
if("tokenizer" not in globals()):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
global tokenizer
|
global tokenizer
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||||
return tokenizer.decode(tokens)
|
return tokenizer.decode(tokens)
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@ -1099,7 +1107,7 @@ def lua_encode(string):
|
||||||
if("tokenizer" not in globals()):
|
if("tokenizer" not in globals()):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
global tokenizer
|
global tokenizer
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||||
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
|
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
Loading…
Reference in New Issue