diff --git a/aiserver.py b/aiserver.py index 35131612..681eacb2 100644 --- a/aiserver.py +++ b/aiserver.py @@ -134,6 +134,7 @@ class vars: wifolders_d = {} # Dictionary of World Info folder UID-info pairs wifolders_l = [] # List of World Info folder UIDs wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID + modelconfig = {} # Raw contents of the model's config.json, or empty dictionary if none found lua_state = None # Lua state of the Lua scripting system lua_koboldbridge = None # `koboldbridge` from bridge.lua lua_kobold = None # `kobold` from` bridge.lua @@ -417,10 +418,16 @@ def loadmodelsettings(): js = json.loads(model_js_config) except Exception as e: try: - model_js_config = open(vars.custmodpth + "/config.json", "r") + try: + model_js_config = open(vars.custmodpth + "/config.json", "r") + except Exception as e: + model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r") + js = json.load(model_js_config) except Exception as e: - model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r") - js = json.load(model_js_config) + js = {} + if vars.model_type == "xglm" or js.get("modelcompat", "j") == "fairseq_lm": + vars.newlinemode = "s" # Default to newline mode if using XGLM + vars.modelconfig = js if("badwordsids" in js): vars.badwordsids = js["badwordsids"] if("nobreakmodel" in js): @@ -1192,7 +1199,8 @@ else: # Load the TPU backend if requested elif(vars.model == "TPUMeshTransformerGPTJ"): print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) - assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth) + if not vars.custmodpth or not os.path.isdir(vars.custmodpth): + raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder") import tpu_mtj_backend tpu_mtj_backend.vars = vars tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback @@ -1200,7 +1208,8 @@ else: tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback - tpu_mtj_backend.load_model(vars.custmodpth) + loadmodelsettings() + tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer @@ -4921,7 +4930,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)): def send_debug(): if vars.debug: debug_info = "" - for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]: + for variable in [["Newline Mode", vars.newlinemode], ["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]: debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1]) emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 00f9a510..a78f93f2 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) "pe_rotary_dims": 64, "seq": 2048, "cores_per_replica": 8, + "tokenizer_class": "GPT2TokenizerFast", + "tokenizer": "gpt2", } params = kwargs + if "compat" in params: + default_params["compat"] = params["compat"] + if default_params["compat"] == "fairseq_lm": + default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M" for param in default_params: if param not in params: params[param] = default_params[param] + # Load tokenizer + if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")): + raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'") + tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"]) + tokenizer = tokenizer_class.from_pretrained(params["tokenizer"]) + # Disable JAX warnings about these two functions having been renamed jax.host_count = jax.process_count jax.host_id = jax.process_index @@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) maps.thread_resources.env = thread_resources_env - tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') global shard_xmap, batch_xmap shard_xmap = __shard_xmap()