Need to set `vars.allowsp` to True before calling `loadsettings()`

This commit is contained in:
Gnome Ann 2022-02-23 21:09:31 -05:00
parent c45ba497c9
commit 8120e4dfa2
1 changed files with 111 additions and 111 deletions

View File

@ -461,6 +461,62 @@ def loadmodelsettings():
if(not vars.gamestarted):
vars.authornotetemplate = vars.setauthornotetemplate
#==================================================================#
# Take settings from vars and write them to client settings file
#==================================================================#
def savesettings():
# Build json to write
js = {}
js["apikey"] = vars.apikey
js["andepth"] = vars.andepth
js["temp"] = vars.temp
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
js["genamt"] = vars.genamt
js["max_length"] = vars.max_length
js["ikgen"] = vars.ikgen
js["formatoptns"] = vars.formatoptns
js["numseqs"] = vars.numseqs
js["widepth"] = vars.widepth
js["useprompt"] = vars.useprompt
js["adventure"] = vars.adventure
js["chatmode"] = vars.chatmode
js["chatname"] = vars.chatname
js["dynamicscan"] = vars.dynamicscan
js["nopromptgen"] = vars.nopromptgen
js["rngpersist"] = vars.rngpersist
js["nogenmod"] = vars.nogenmod
js["autosave"] = vars.autosave
js["welcome"] = vars.welcome
js["newlinemode"] = vars.newlinemode
js["antemplate"] = vars.setauthornotetemplate
js["userscripts"] = vars.userscripts
js["corescript"] = vars.corescript
js["softprompt"] = vars.spfilename
# Write it
if not os.path.exists('settings'):
os.mkdir('settings')
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
try:
file.write(json.dumps(js, indent=3))
finally:
file.close()
#==================================================================#
# Don't save settings unless 2 seconds have passed without modification
#==================================================================#
@debounce(2)
def settingschanged():
print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
savesettings()
#==================================================================#
# Read settings from client file JSON and send to vars
#==================================================================#
@ -550,6 +606,60 @@ def loadsettings():
file.close()
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def spRequest(filename):
vars.spfilename = ""
settingschanged()
if(len(filename) == 0):
vars.sp = None
vars.sp_length = 0
return
global np
if 'np' not in globals():
import numpy as np
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
assert isinstance(z, zipfile.ZipFile)
with z.open('meta.json') as f:
vars.spmeta = json.load(f)
z.close()
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
tensor = f['tensor.npy']
# If the tensor is in bfloat16 format, convert it to float32
if(tensor.dtype == 'V2'):
tensor.dtype = np.uint16
tensor = np.uint32(tensor) << 16
tensor.dtype = np.float32
if(tensor.dtype != np.float16):
tensor = np.float32(tensor)
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
vars.sp_length = tensor.shape[-2]
vars.spmeta["n_tokens"] = vars.sp_length
if(vars.model in ("TPUMeshTransformerGPTJ",)):
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(
tpu_mtj_backend.params["cores_per_replica"],
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)
vars.spfilename = filename
settingschanged()
#==================================================================#
# Startup
#==================================================================#
@ -1298,10 +1408,10 @@ else:
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
vars.allowsp = True
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
else:
@ -2397,62 +2507,6 @@ def sendsettings():
if(not frm["id"] in vars.formatoptns):
vars.formatoptns[frm["id"]] = False;
#==================================================================#
# Take settings from vars and write them to client settings file
#==================================================================#
def savesettings():
# Build json to write
js = {}
js["apikey"] = vars.apikey
js["andepth"] = vars.andepth
js["temp"] = vars.temp
js["top_p"] = vars.top_p
js["top_k"] = vars.top_k
js["tfs"] = vars.tfs
js["rep_pen"] = vars.rep_pen
js["rep_pen_slope"] = vars.rep_pen_slope
js["rep_pen_range"] = vars.rep_pen_range
js["genamt"] = vars.genamt
js["max_length"] = vars.max_length
js["ikgen"] = vars.ikgen
js["formatoptns"] = vars.formatoptns
js["numseqs"] = vars.numseqs
js["widepth"] = vars.widepth
js["useprompt"] = vars.useprompt
js["adventure"] = vars.adventure
js["chatmode"] = vars.chatmode
js["chatname"] = vars.chatname
js["dynamicscan"] = vars.dynamicscan
js["nopromptgen"] = vars.nopromptgen
js["rngpersist"] = vars.rngpersist
js["nogenmod"] = vars.nogenmod
js["autosave"] = vars.autosave
js["welcome"] = vars.welcome
js["newlinemode"] = vars.newlinemode
js["antemplate"] = vars.setauthornotetemplate
js["userscripts"] = vars.userscripts
js["corescript"] = vars.corescript
js["softprompt"] = vars.spfilename
# Write it
if not os.path.exists('settings'):
os.mkdir('settings')
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
try:
file.write(json.dumps(js, indent=3))
finally:
file.close()
#==================================================================#
# Don't save settings unless 2 seconds have passed without modification
#==================================================================#
@debounce(2)
def settingschanged():
print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
savesettings()
#==================================================================#
# Set value of gamesaved
#==================================================================#
@ -4493,60 +4547,6 @@ def loadRequest(loadpath, filename=None):
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def spRequest(filename):
vars.spfilename = ""
settingschanged()
if(len(filename) == 0):
vars.sp = None
vars.sp_length = 0
return
global np
if 'np' not in globals():
import numpy as np
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
assert isinstance(z, zipfile.ZipFile)
with z.open('meta.json') as f:
vars.spmeta = json.load(f)
z.close()
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
tensor = f['tensor.npy']
# If the tensor is in bfloat16 format, convert it to float32
if(tensor.dtype == 'V2'):
tensor.dtype = np.uint16
tensor = np.uint32(tensor) << 16
tensor.dtype = np.float32
if(tensor.dtype != np.float16):
tensor = np.float32(tensor)
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
vars.sp_length = tensor.shape[-2]
vars.spmeta["n_tokens"] = vars.sp_length
if(vars.model in ("TPUMeshTransformerGPTJ",)):
rows = tensor.shape[0]
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(
tpu_mtj_backend.params["cores_per_replica"],
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)
vars.spfilename = filename
settingschanged()
#==================================================================#
# Import an AIDungon game exported with Mimi's tool
#==================================================================#