Merge pull request #99 from VE-FORBRYDERNE/model-patch

Model loading fixes
This commit is contained in:
henk717 2022-03-13 11:10:15 +01:00 committed by GitHub
commit ccadeabbde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 27 deletions

View File

@ -10,6 +10,8 @@ import eventlet
eventlet.monkey_patch(all=True, thread=False)
import os
os.system("")
__file__ = os.path.dirname(os.path.realpath(__file__))
os.chdir(__file__)
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from eventlet import tpool
@ -1024,7 +1026,7 @@ log.setLevel(logging.ERROR)
print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
from flask import Flask, render_template, Response, request, copy_current_request_context
from flask_socketio import SocketIO, emit
app = Flask(__name__)
app = Flask(__name__, root_path=os.getcwd())
app.config['SECRET KEY'] = 'secret!'
socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))
@ -1101,7 +1103,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
return lazy_load_callback
lazy_load_config_path = os.path.join(path.dirname(path.realpath(__file__)), "maps", vars.model_type + ".json")
lazy_load_config_path = os.path.join("maps", vars.model_type + ".json")
if(vars.lazy_load and "model_config" in globals() and os.path.isfile(lazy_load_config_path)):
with open(lazy_load_config_path) as f:
lazy_load_spec = json.load(f)
@ -2193,16 +2195,16 @@ vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
# Load bridge.lua
bridged = {
"corescript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "cores"),
"userscript_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"config_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "userscripts"),
"lib_paths": vars.lua_state.table(os.path.join(os.path.dirname(os.path.realpath(__file__)), "lualibs"), os.path.join(os.path.dirname(os.path.realpath(__file__)), "extern", "lualibs")),
"corescript_path": "cores",
"userscript_path": "userscripts",
"config_path": "userscripts",
"lib_paths": vars.lua_state.table("lualibs", os.path.join("extern", "lualibs")),
"vars": vars,
}
for kwarg in _bridged:
bridged[kwarg] = _bridged[kwarg]
try:
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile(os.path.join(os.path.dirname(os.path.realpath(__file__)), "bridge.lua"))(
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile("bridge.lua")(
vars.lua_state.globals().python,
bridged,
)

View File

@ -65,30 +65,30 @@ def getdirpath(dir, title):
# Returns the path (as a string) to the given story by its name
#==================================================================#
def storypath(name):
return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json")
return path.join("stories", name + ".json")
#==================================================================#
# Returns the path (as a string) to the given soft prompt by its filename
#==================================================================#
def sppath(filename):
return path.join(path.dirname(path.realpath(__file__)), "softprompts", filename)
return path.join("softprompts", filename)
#==================================================================#
# Returns the path (as a string) to the given username by its filename
#==================================================================#
def uspath(filename):
return path.join(path.dirname(path.realpath(__file__)), "userscripts", filename)
return path.join("userscripts", filename)
#==================================================================#
# Returns an array of dicts containing story files in /stories
#==================================================================#
def getstoryfiles():
list = []
for file in listdir(path.dirname(path.realpath(__file__))+"/stories"):
for file in listdir("stories"):
if file.endswith(".json"):
ob = {}
ob["name"] = file.replace(".json", "")
f = open(path.dirname(path.realpath(__file__))+"/stories/"+file, "r")
f = open("stories/"+file, "r")
try:
js = json.load(f)
except:
@ -112,7 +112,7 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
if 'np' not in globals():
import numpy as np
try:
z = zipfile.ZipFile(path.dirname(path.realpath(__file__))+"/softprompts/"+filename)
z = zipfile.ZipFile("softprompts/"+filename)
with z.open('tensor.npy') as f:
# Read only the header of the npy file, for efficiency reasons
version: Tuple[int, int] = np.lib.format.read_magic(f)
@ -140,8 +140,8 @@ def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile,
#==================================================================#
def getspfiles(model_dimension: int):
lst = []
os.makedirs(path.dirname(path.realpath(__file__))+"/softprompts", exist_ok=True)
for file in listdir(path.dirname(path.realpath(__file__))+"/softprompts"):
os.makedirs("softprompts", exist_ok=True)
for file in listdir("softprompts"):
if not file.endswith(".zip"):
continue
z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
@ -174,8 +174,8 @@ def getspfiles(model_dimension: int):
#==================================================================#
def getusfiles(long_desc=False):
lst = []
os.makedirs(path.dirname(path.realpath(__file__))+"/userscripts", exist_ok=True)
for file in listdir(path.dirname(path.realpath(__file__))+"/userscripts"):
os.makedirs("userscripts", exist_ok=True)
for file in listdir("userscripts"):
if file.endswith(".lua"):
ob = {}
ob["filename"] = file

View File

@ -822,7 +822,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
# Try to convert HF config.json to MTJ config
if hf_checkpoint:
spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
spec_path = os.path.join("maps", vars.model_type + ".json")
if not os.path.isfile(spec_path):
raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
with open(spec_path) as f:
@ -1035,29 +1035,38 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
if(os.path.isdir(vars.custmodpth)):
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
except ValueError as e:
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
except ValueError as e:
except Exception as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))