Merge branch 'united' into gui-and-scripting

This commit is contained in:
Gnome Ann 2021-12-23 13:02:01 -05:00
commit 924c48a6d7

View File

@ -397,19 +397,28 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717
if(not vars.model in ["NeoCustom", "GPT2Custom"]):
vars.custmodpth = vars.model
elif(vars.model == "NeoCustom"):
vars.model = os.path.basename(os.path.normpath(vars.custmodpth))
# Get the model_type from the config or assume a model type if it isn't present
from transformers import AutoConfig
try:
model_config = AutoConfig.from_pretrained(vars.custmodpth)
except ValueError as e:
vars.model_type = "not_found"
if(not vars.model_type == "not_found"):
vars.model_type = model_config.model_type
elif(vars.model == "NeoCustom"):
vars.model_type = "gpt_neo"
elif(vars.model == "GPT2Custom"):
vars.model_type = "gpt2"
if(os.path.isdir(vars.custmodpth.replace('/', '_'))):
try:
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), cache_dir="cache/")
vars.model_type = model_config.model_type
except ValueError as e:
vars.model_type = "not_found"
else:
try:
model_config = AutoConfig.from_pretrained(vars.custmodpth, cache_dir="cache/")
vars.model_type = model_config.model_type
except ValueError as e:
vars.model_type = "not_found"
if(vars.model_type == "not_found" and vars.model == "NeoCustom"):
vars.model_type = "gpt_neo"
elif(vars.model_type == "not_found" and vars.model == "GPT2Custom"):
vars.model_type = "gpt2"
elif(vars.model_type == "not_found"):
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
@ -844,32 +853,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
else:
yield False
# If custom GPT Neo model was chosen
if(vars.model == "NeoCustom"):
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
with(maybe_use_float16()):
if("model_type" in js):
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
else:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **maybe_low_cpu_mem_usage())
vars.modeldim = get_hidden_size_from_model(model)
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
if(vars.hascuda):
if(vars.usegpu):
model = model.half().to(vars.gpu_device)
generator = model.generate
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
device_config(model)
else:
model = model.to('cpu').float()
generator = model.generate
else:
model = model.to('cpu').float()
generator = model.generate
# If custom GPT2 model was chosen
elif(vars.model == "GPT2Custom"):
if(vars.model == "GPT2Custom"):
model_config = open(vars.custmodpth + "/config.json", "r")
js = json.load(model_config)
with(maybe_use_float16()):
@ -883,7 +868,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
else:
model = model.to('cpu').float()
generator = model.generate
# If base HuggingFace model was chosen
# Use the Generic implementation
else:
lowmem = maybe_low_cpu_mem_usage()
# We must disable low_cpu_mem_usage (by setting lowmem to {}) if
@ -893,17 +878,29 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
lowmem = {}
# Download model from Huggingface if it does not exist, otherwise load locally
if(os.path.isdir(vars.model.replace('/', '_'))):
if(os.path.isdir(vars.custmodpth)):
with(maybe_use_float16()):
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
elif(os.path.isdir(vars.model.replace('/', '_'))):
with(maybe_use_float16()):
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
try:
model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
else:
print("Model does not exist locally, attempting to download from Huggingface...")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
with(maybe_use_float16()):
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
model = model.half()
import shutil
shutil.rmtree("cache/")
@ -938,15 +935,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
else:
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
else:
# If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B")
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
elif(vars.model == "OAI"):
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
@ -1097,7 +1094,7 @@ def lua_decode(tokens):
if("tokenizer" not in globals()):
from transformers import GPT2TokenizerFast
global tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
return tokenizer.decode(tokens)
#==================================================================#
@ -1108,7 +1105,7 @@ def lua_encode(string):
if("tokenizer" not in globals()):
from transformers import GPT2TokenizerFast
global tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
#==================================================================#