Allow TPU models to specify settings/config in config.json

This commit is contained in:
Gnome Ann
2022-02-23 18:22:18 -05:00
parent 6151d16df0
commit ad10ac8871
2 changed files with 27 additions and 7 deletions

View File

@ -134,6 +134,7 @@ class vars:
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
wifolders_l = [] # List of World Info folder UIDs
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
modelconfig = {} # Raw contents of the model's config.json, or empty dictionary if none found
lua_state = None # Lua state of the Lua scripting system
lua_koboldbridge = None # `koboldbridge` from bridge.lua
lua_kobold = None # `kobold` from` bridge.lua
@ -417,10 +418,16 @@ def loadmodelsettings():
js = json.loads(model_js_config)
except Exception as e:
try:
model_js_config = open(vars.custmodpth + "/config.json", "r")
try:
model_js_config = open(vars.custmodpth + "/config.json", "r")
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_js_config)
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_js_config)
js = {}
if vars.model_type == "xglm" or js.get("modelcompat", "j") == "fairseq_lm":
vars.newlinemode = "s" # Default to </s> newline mode if using XGLM
vars.modelconfig = js
if("badwordsids" in js):
vars.badwordsids = js["badwordsids"]
if("nobreakmodel" in js):
@ -1192,7 +1199,8 @@ else:
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
tpu_mtj_backend.vars = vars
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
@ -1200,7 +1208,8 @@ else:
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
tpu_mtj_backend.load_model(vars.custmodpth)
loadmodelsettings()
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
@ -4921,7 +4930,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
def send_debug():
if vars.debug:
debug_info = ""
for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
for variable in [["Newline Mode", vars.newlinemode], ["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)