Merge pull request #83 from VE-FORBRYDERNE/loadsettings
Load settings earlier to avoid TPU badwords issues
This commit is contained in:
commit
1fc173890e
417
aiserver.py
417
aiserver.py
|
@ -461,6 +461,200 @@ def loadmodelsettings():
|
||||||
if(not vars.gamestarted):
|
if(not vars.gamestarted):
|
||||||
vars.authornotetemplate = vars.setauthornotetemplate
|
vars.authornotetemplate = vars.setauthornotetemplate
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Take settings from vars and write them to client settings file
|
||||||
|
#==================================================================#
|
||||||
|
def savesettings():
|
||||||
|
# Build json to write
|
||||||
|
js = {}
|
||||||
|
js["apikey"] = vars.apikey
|
||||||
|
js["andepth"] = vars.andepth
|
||||||
|
js["temp"] = vars.temp
|
||||||
|
js["top_p"] = vars.top_p
|
||||||
|
js["top_k"] = vars.top_k
|
||||||
|
js["tfs"] = vars.tfs
|
||||||
|
js["rep_pen"] = vars.rep_pen
|
||||||
|
js["rep_pen_slope"] = vars.rep_pen_slope
|
||||||
|
js["rep_pen_range"] = vars.rep_pen_range
|
||||||
|
js["genamt"] = vars.genamt
|
||||||
|
js["max_length"] = vars.max_length
|
||||||
|
js["ikgen"] = vars.ikgen
|
||||||
|
js["formatoptns"] = vars.formatoptns
|
||||||
|
js["numseqs"] = vars.numseqs
|
||||||
|
js["widepth"] = vars.widepth
|
||||||
|
js["useprompt"] = vars.useprompt
|
||||||
|
js["adventure"] = vars.adventure
|
||||||
|
js["chatmode"] = vars.chatmode
|
||||||
|
js["chatname"] = vars.chatname
|
||||||
|
js["dynamicscan"] = vars.dynamicscan
|
||||||
|
js["nopromptgen"] = vars.nopromptgen
|
||||||
|
js["rngpersist"] = vars.rngpersist
|
||||||
|
js["nogenmod"] = vars.nogenmod
|
||||||
|
js["autosave"] = vars.autosave
|
||||||
|
js["welcome"] = vars.welcome
|
||||||
|
js["newlinemode"] = vars.newlinemode
|
||||||
|
|
||||||
|
js["antemplate"] = vars.setauthornotetemplate
|
||||||
|
|
||||||
|
js["userscripts"] = vars.userscripts
|
||||||
|
js["corescript"] = vars.corescript
|
||||||
|
js["softprompt"] = vars.spfilename
|
||||||
|
|
||||||
|
# Write it
|
||||||
|
if not os.path.exists('settings'):
|
||||||
|
os.mkdir('settings')
|
||||||
|
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
|
||||||
|
try:
|
||||||
|
file.write(json.dumps(js, indent=3))
|
||||||
|
finally:
|
||||||
|
file.close()
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Don't save settings unless 2 seconds have passed without modification
|
||||||
|
#==================================================================#
|
||||||
|
@debounce(2)
|
||||||
|
def settingschanged():
|
||||||
|
print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
|
||||||
|
savesettings()
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Read settings from client file JSON and send to vars
|
||||||
|
#==================================================================#
|
||||||
|
def loadsettings():
|
||||||
|
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||||
|
# Read file contents into JSON object
|
||||||
|
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
|
||||||
|
js = json.load(file)
|
||||||
|
|
||||||
|
# Copy file contents to vars
|
||||||
|
if("apikey" in js):
|
||||||
|
vars.apikey = js["apikey"]
|
||||||
|
if("andepth" in js):
|
||||||
|
vars.andepth = js["andepth"]
|
||||||
|
if("temp" in js):
|
||||||
|
vars.temp = js["temp"]
|
||||||
|
if("top_p" in js):
|
||||||
|
vars.top_p = js["top_p"]
|
||||||
|
if("top_k" in js):
|
||||||
|
vars.top_k = js["top_k"]
|
||||||
|
if("tfs" in js):
|
||||||
|
vars.tfs = js["tfs"]
|
||||||
|
if("rep_pen" in js):
|
||||||
|
vars.rep_pen = js["rep_pen"]
|
||||||
|
if("rep_pen_slope" in js):
|
||||||
|
vars.rep_pen_slope = js["rep_pen_slope"]
|
||||||
|
if("rep_pen_range" in js):
|
||||||
|
vars.rep_pen_range = js["rep_pen_range"]
|
||||||
|
if("genamt" in js):
|
||||||
|
vars.genamt = js["genamt"]
|
||||||
|
if("max_length" in js):
|
||||||
|
vars.max_length = js["max_length"]
|
||||||
|
if("ikgen" in js):
|
||||||
|
vars.ikgen = js["ikgen"]
|
||||||
|
if("formatoptns" in js):
|
||||||
|
vars.formatoptns = js["formatoptns"]
|
||||||
|
if("numseqs" in js):
|
||||||
|
vars.numseqs = js["numseqs"]
|
||||||
|
if("widepth" in js):
|
||||||
|
vars.widepth = js["widepth"]
|
||||||
|
if("useprompt" in js):
|
||||||
|
vars.useprompt = js["useprompt"]
|
||||||
|
if("adventure" in js):
|
||||||
|
vars.adventure = js["adventure"]
|
||||||
|
if("chatmode" in js):
|
||||||
|
vars.chatmode = js["chatmode"]
|
||||||
|
if("chatname" in js):
|
||||||
|
vars.chatname = js["chatname"]
|
||||||
|
if("dynamicscan" in js):
|
||||||
|
vars.dynamicscan = js["dynamicscan"]
|
||||||
|
if("nopromptgen" in js):
|
||||||
|
vars.nopromptgen = js["nopromptgen"]
|
||||||
|
if("rngpersist" in js):
|
||||||
|
vars.rngpersist = js["rngpersist"]
|
||||||
|
if("nogenmod" in js):
|
||||||
|
vars.nogenmod = js["nogenmod"]
|
||||||
|
if("autosave" in js):
|
||||||
|
vars.autosave = js["autosave"]
|
||||||
|
if("newlinemode" in js):
|
||||||
|
vars.newlinemode = js["newlinemode"]
|
||||||
|
if("welcome" in js):
|
||||||
|
vars.welcome = js["welcome"]
|
||||||
|
|
||||||
|
if("antemplate" in js):
|
||||||
|
vars.setauthornotetemplate = js["antemplate"]
|
||||||
|
if(not vars.gamestarted):
|
||||||
|
vars.authornotetemplate = vars.setauthornotetemplate
|
||||||
|
|
||||||
|
if("userscripts" in js):
|
||||||
|
vars.userscripts = []
|
||||||
|
for userscript in js["userscripts"]:
|
||||||
|
if type(userscript) is not str:
|
||||||
|
continue
|
||||||
|
userscript = userscript.strip()
|
||||||
|
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
|
||||||
|
vars.userscripts.append(userscript)
|
||||||
|
|
||||||
|
if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
|
||||||
|
vars.corescript = js["corescript"]
|
||||||
|
else:
|
||||||
|
vars.corescript = "default.lua"
|
||||||
|
|
||||||
|
file.close()
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Load a soft prompt from a file
|
||||||
|
#==================================================================#
|
||||||
|
def spRequest(filename):
|
||||||
|
vars.spfilename = ""
|
||||||
|
settingschanged()
|
||||||
|
|
||||||
|
if(len(filename) == 0):
|
||||||
|
vars.sp = None
|
||||||
|
vars.sp_length = 0
|
||||||
|
return
|
||||||
|
|
||||||
|
global np
|
||||||
|
if 'np' not in globals():
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
|
||||||
|
assert isinstance(z, zipfile.ZipFile)
|
||||||
|
with z.open('meta.json') as f:
|
||||||
|
vars.spmeta = json.load(f)
|
||||||
|
z.close()
|
||||||
|
|
||||||
|
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
|
||||||
|
tensor = f['tensor.npy']
|
||||||
|
|
||||||
|
# If the tensor is in bfloat16 format, convert it to float32
|
||||||
|
if(tensor.dtype == 'V2'):
|
||||||
|
tensor.dtype = np.uint16
|
||||||
|
tensor = np.uint32(tensor) << 16
|
||||||
|
tensor.dtype = np.float32
|
||||||
|
|
||||||
|
if(tensor.dtype != np.float16):
|
||||||
|
tensor = np.float32(tensor)
|
||||||
|
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
|
||||||
|
|
||||||
|
vars.sp_length = tensor.shape[-2]
|
||||||
|
vars.spmeta["n_tokens"] = vars.sp_length
|
||||||
|
|
||||||
|
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
||||||
|
rows = tensor.shape[0]
|
||||||
|
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||||
|
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||||
|
tensor = tensor.reshape(
|
||||||
|
tpu_mtj_backend.params["cores_per_replica"],
|
||||||
|
-1,
|
||||||
|
tpu_mtj_backend.params["d_model"],
|
||||||
|
)
|
||||||
|
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
|
||||||
|
else:
|
||||||
|
vars.sp = torch.from_numpy(tensor)
|
||||||
|
|
||||||
|
vars.spfilename = filename
|
||||||
|
settingschanged()
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Startup
|
# Startup
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@ -573,6 +767,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||||
vars.model_type = "gpt_neo"
|
vars.model_type = "gpt_neo"
|
||||||
loadmodelsettings()
|
loadmodelsettings()
|
||||||
|
loadsettings()
|
||||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||||
vars.hascuda = torch.cuda.is_available()
|
vars.hascuda = torch.cuda.is_available()
|
||||||
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
|
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
|
||||||
|
@ -1191,9 +1386,11 @@ else:
|
||||||
if(vars.model == "Colab"):
|
if(vars.model == "Colab"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
|
||||||
|
loadsettings()
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||||
|
loadsettings()
|
||||||
# Load the TPU backend if requested
|
# Load the TPU backend if requested
|
||||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
|
@ -1206,11 +1403,14 @@ else:
|
||||||
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
|
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
|
||||||
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
|
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
|
||||||
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
|
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
|
||||||
loadmodelsettings()
|
|
||||||
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
|
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
|
loadmodelsettings()
|
||||||
|
loadsettings()
|
||||||
|
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
|
||||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||||
tokenizer = tpu_mtj_backend.tokenizer
|
tokenizer = tpu_mtj_backend.tokenizer
|
||||||
|
else:
|
||||||
|
loadsettings()
|
||||||
|
|
||||||
# Set up Flask routes
|
# Set up Flask routes
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
|
@ -2302,152 +2502,6 @@ def sendsettings():
|
||||||
if(not frm["id"] in vars.formatoptns):
|
if(not frm["id"] in vars.formatoptns):
|
||||||
vars.formatoptns[frm["id"]] = False;
|
vars.formatoptns[frm["id"]] = False;
|
||||||
|
|
||||||
#==================================================================#
|
|
||||||
# Take settings from vars and write them to client settings file
|
|
||||||
#==================================================================#
|
|
||||||
def savesettings():
|
|
||||||
# Build json to write
|
|
||||||
js = {}
|
|
||||||
js["apikey"] = vars.apikey
|
|
||||||
js["andepth"] = vars.andepth
|
|
||||||
js["temp"] = vars.temp
|
|
||||||
js["top_p"] = vars.top_p
|
|
||||||
js["top_k"] = vars.top_k
|
|
||||||
js["tfs"] = vars.tfs
|
|
||||||
js["rep_pen"] = vars.rep_pen
|
|
||||||
js["rep_pen_slope"] = vars.rep_pen_slope
|
|
||||||
js["rep_pen_range"] = vars.rep_pen_range
|
|
||||||
js["genamt"] = vars.genamt
|
|
||||||
js["max_length"] = vars.max_length
|
|
||||||
js["ikgen"] = vars.ikgen
|
|
||||||
js["formatoptns"] = vars.formatoptns
|
|
||||||
js["numseqs"] = vars.numseqs
|
|
||||||
js["widepth"] = vars.widepth
|
|
||||||
js["useprompt"] = vars.useprompt
|
|
||||||
js["adventure"] = vars.adventure
|
|
||||||
js["chatmode"] = vars.chatmode
|
|
||||||
js["chatname"] = vars.chatname
|
|
||||||
js["dynamicscan"] = vars.dynamicscan
|
|
||||||
js["nopromptgen"] = vars.nopromptgen
|
|
||||||
js["rngpersist"] = vars.rngpersist
|
|
||||||
js["nogenmod"] = vars.nogenmod
|
|
||||||
js["autosave"] = vars.autosave
|
|
||||||
js["welcome"] = vars.welcome
|
|
||||||
js["newlinemode"] = vars.newlinemode
|
|
||||||
|
|
||||||
js["antemplate"] = vars.setauthornotetemplate
|
|
||||||
|
|
||||||
js["userscripts"] = vars.userscripts
|
|
||||||
js["corescript"] = vars.corescript
|
|
||||||
js["softprompt"] = vars.spfilename
|
|
||||||
|
|
||||||
# Write it
|
|
||||||
if not os.path.exists('settings'):
|
|
||||||
os.mkdir('settings')
|
|
||||||
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
|
|
||||||
try:
|
|
||||||
file.write(json.dumps(js, indent=3))
|
|
||||||
finally:
|
|
||||||
file.close()
|
|
||||||
|
|
||||||
#==================================================================#
|
|
||||||
# Read settings from client file JSON and send to vars
|
|
||||||
#==================================================================#
|
|
||||||
def loadsettings():
|
|
||||||
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
|
||||||
# Read file contents into JSON object
|
|
||||||
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
|
|
||||||
js = json.load(file)
|
|
||||||
|
|
||||||
# Copy file contents to vars
|
|
||||||
if("apikey" in js):
|
|
||||||
vars.apikey = js["apikey"]
|
|
||||||
if("andepth" in js):
|
|
||||||
vars.andepth = js["andepth"]
|
|
||||||
if("temp" in js):
|
|
||||||
vars.temp = js["temp"]
|
|
||||||
if("top_p" in js):
|
|
||||||
vars.top_p = js["top_p"]
|
|
||||||
if("top_k" in js):
|
|
||||||
vars.top_k = js["top_k"]
|
|
||||||
if("tfs" in js):
|
|
||||||
vars.tfs = js["tfs"]
|
|
||||||
if("rep_pen" in js):
|
|
||||||
vars.rep_pen = js["rep_pen"]
|
|
||||||
if("rep_pen_slope" in js):
|
|
||||||
vars.rep_pen_slope = js["rep_pen_slope"]
|
|
||||||
if("rep_pen_range" in js):
|
|
||||||
vars.rep_pen_range = js["rep_pen_range"]
|
|
||||||
if("genamt" in js):
|
|
||||||
vars.genamt = js["genamt"]
|
|
||||||
if("max_length" in js):
|
|
||||||
vars.max_length = js["max_length"]
|
|
||||||
if("ikgen" in js):
|
|
||||||
vars.ikgen = js["ikgen"]
|
|
||||||
if("formatoptns" in js):
|
|
||||||
vars.formatoptns = js["formatoptns"]
|
|
||||||
if("numseqs" in js):
|
|
||||||
vars.numseqs = js["numseqs"]
|
|
||||||
if("widepth" in js):
|
|
||||||
vars.widepth = js["widepth"]
|
|
||||||
if("useprompt" in js):
|
|
||||||
vars.useprompt = js["useprompt"]
|
|
||||||
if("adventure" in js):
|
|
||||||
vars.adventure = js["adventure"]
|
|
||||||
if("chatmode" in js):
|
|
||||||
vars.chatmode = js["chatmode"]
|
|
||||||
if("chatname" in js):
|
|
||||||
vars.chatname = js["chatname"]
|
|
||||||
if("dynamicscan" in js):
|
|
||||||
vars.dynamicscan = js["dynamicscan"]
|
|
||||||
if("nopromptgen" in js):
|
|
||||||
vars.nopromptgen = js["nopromptgen"]
|
|
||||||
if("rngpersist" in js):
|
|
||||||
vars.rngpersist = js["rngpersist"]
|
|
||||||
if("nogenmod" in js):
|
|
||||||
vars.nogenmod = js["nogenmod"]
|
|
||||||
if("autosave" in js):
|
|
||||||
vars.autosave = js["autosave"]
|
|
||||||
if("newlinemode" in js):
|
|
||||||
vars.newlinemode = js["newlinemode"]
|
|
||||||
if("welcome" in js):
|
|
||||||
vars.welcome = js["welcome"]
|
|
||||||
|
|
||||||
if("antemplate" in js):
|
|
||||||
vars.setauthornotetemplate = js["antemplate"]
|
|
||||||
if(not vars.gamestarted):
|
|
||||||
vars.authornotetemplate = vars.setauthornotetemplate
|
|
||||||
|
|
||||||
if("userscripts" in js):
|
|
||||||
vars.userscripts = []
|
|
||||||
for userscript in js["userscripts"]:
|
|
||||||
if type(userscript) is not str:
|
|
||||||
continue
|
|
||||||
userscript = userscript.strip()
|
|
||||||
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
|
|
||||||
vars.userscripts.append(userscript)
|
|
||||||
|
|
||||||
if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
|
|
||||||
vars.corescript = js["corescript"]
|
|
||||||
else:
|
|
||||||
vars.corescript = "default.lua"
|
|
||||||
|
|
||||||
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
|
|
||||||
spRequest(js["softprompt"])
|
|
||||||
else:
|
|
||||||
vars.spfilename = ""
|
|
||||||
|
|
||||||
file.close()
|
|
||||||
|
|
||||||
|
|
||||||
#==================================================================#
|
|
||||||
# Don't save settings unless 2 seconds have passed without modification
|
|
||||||
#==================================================================#
|
|
||||||
@debounce(2)
|
|
||||||
def settingschanged():
|
|
||||||
print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
|
|
||||||
savesettings()
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Set value of gamesaved
|
# Set value of gamesaved
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@ -4488,60 +4542,6 @@ def loadRequest(loadpath, filename=None):
|
||||||
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
|
||||||
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
|
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
|
||||||
|
|
||||||
#==================================================================#
|
|
||||||
# Load a soft prompt from a file
|
|
||||||
#==================================================================#
|
|
||||||
def spRequest(filename):
|
|
||||||
vars.spfilename = ""
|
|
||||||
settingschanged()
|
|
||||||
|
|
||||||
if(len(filename) == 0):
|
|
||||||
vars.sp = None
|
|
||||||
vars.sp_length = 0
|
|
||||||
return
|
|
||||||
|
|
||||||
global np
|
|
||||||
if 'np' not in globals():
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
|
|
||||||
assert isinstance(z, zipfile.ZipFile)
|
|
||||||
with z.open('meta.json') as f:
|
|
||||||
vars.spmeta = json.load(f)
|
|
||||||
z.close()
|
|
||||||
|
|
||||||
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
|
|
||||||
tensor = f['tensor.npy']
|
|
||||||
|
|
||||||
# If the tensor is in bfloat16 format, convert it to float32
|
|
||||||
if(tensor.dtype == 'V2'):
|
|
||||||
tensor.dtype = np.uint16
|
|
||||||
tensor = np.uint32(tensor) << 16
|
|
||||||
tensor.dtype = np.float32
|
|
||||||
|
|
||||||
if(tensor.dtype != np.float16):
|
|
||||||
tensor = np.float32(tensor)
|
|
||||||
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
|
|
||||||
|
|
||||||
vars.sp_length = tensor.shape[-2]
|
|
||||||
vars.spmeta["n_tokens"] = vars.sp_length
|
|
||||||
|
|
||||||
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
|
||||||
rows = tensor.shape[0]
|
|
||||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
|
||||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
|
||||||
tensor = tensor.reshape(
|
|
||||||
tpu_mtj_backend.params["cores_per_replica"],
|
|
||||||
-1,
|
|
||||||
tpu_mtj_backend.params["d_model"],
|
|
||||||
)
|
|
||||||
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
|
|
||||||
else:
|
|
||||||
vars.sp = torch.from_numpy(tensor)
|
|
||||||
|
|
||||||
vars.spfilename = filename
|
|
||||||
settingschanged()
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Import an AIDungon game exported with Mimi's tool
|
# Import an AIDungon game exported with Mimi's tool
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@ -4886,9 +4886,6 @@ def randomGameRequest(topic, memory=""):
|
||||||
vars.memory = memory
|
vars.memory = memory
|
||||||
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
|
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
|
||||||
|
|
||||||
# Load desired settings from both the model and the users config file
|
|
||||||
loadsettings()
|
|
||||||
|
|
||||||
# Prevent tokenizer from taking extra time the first time it's used
|
# Prevent tokenizer from taking extra time the first time it's used
|
||||||
def __preempt_tokenizer():
|
def __preempt_tokenizer():
|
||||||
if("tokenizer" not in globals()):
|
if("tokenizer" not in globals()):
|
||||||
|
@ -4897,6 +4894,16 @@ def __preempt_tokenizer():
|
||||||
tokenizer.encode(utils.encodenewlines("eunoia"))
|
tokenizer.encode(utils.encodenewlines("eunoia"))
|
||||||
threading.Thread(target=__preempt_tokenizer).start()
|
threading.Thread(target=__preempt_tokenizer).start()
|
||||||
|
|
||||||
|
# Load soft prompt specified by the settings file, if applicable
|
||||||
|
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||||
|
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
|
||||||
|
js = json.load(file)
|
||||||
|
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
|
||||||
|
spRequest(js["softprompt"])
|
||||||
|
else:
|
||||||
|
vars.spfilename = ""
|
||||||
|
file.close()
|
||||||
|
|
||||||
# Precompile TPU backend if required
|
# Precompile TPU backend if required
|
||||||
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
||||||
soft_tokens = tpumtjgetsofttokens()
|
soft_tokens = tpumtjgetsofttokens()
|
||||||
|
|
Loading…
Reference in New Issue